From fc5d705b83fef6f89c7c090903a3ff6fbd7fd0dd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 27 Aug 2023 10:07:01 -0700
Subject: [PATCH] dynamic_size -> dynamic_img_size, add dynamic_img_pad for padding option

---
 timm/layers/patch_embed.py               | 10 ++++++++--
 timm/models/deit.py                      |  2 +-
 timm/models/eva.py                       | 12 +++++++-----
 timm/models/vision_transformer.py        | 12 +++++++-----
 timm/models/vision_transformer_hybrid.py | 12 +++++++++++-
 5 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py
index 473b095a..ec8986d3 100644
--- a/timm/layers/patch_embed.py
+++ b/timm/layers/patch_embed.py
@@ -26,6 +26,7 @@ class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
     """
     output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
 
     def __init__(
             self,
@@ -38,6 +39,7 @@ class PatchEmbed(nn.Module):
             output_fmt: Optional[str] = None,
             bias: bool = True,
             strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
     ):
         super().__init__()
         self.patch_size = to_2tuple(patch_size)
@@ -58,6 +60,7 @@ class PatchEmbed(nn.Module):
             self.flatten = flatten
             self.output_fmt = Format.NCHW
         self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
 
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -68,7 +71,7 @@ class PatchEmbed(nn.Module):
             if self.strict_img_size:
                 _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
                 _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
-            else:
+            elif not self.dynamic_img_pad:
                 _assert(
                     H % self.patch_size[0] == 0,
                     f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
@@ -77,7 +80,10 @@ class PatchEmbed(nn.Module):
                     W % self.patch_size[1] == 0,
                     f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
                 )
-
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
diff --git a/timm/models/deit.py b/timm/models/deit.py
index c5459754..f80087e8 100644
--- a/timm/models/deit.py
+++ b/timm/models/deit.py
@@ -74,7 +74,7 @@ class VisionTransformerDistilled(VisionTransformer):
         self.distilled_training = enable
 
     def _pos_embed(self, x):
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,
diff --git a/timm/models/eva.py b/timm/models/eva.py
index 7235132c..81bcce52 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -367,7 +367,8 @@ class Eva(nn.Module):
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
-            dynamic_size: bool = False,
+            dynamic_img_size: bool = False,
+            dynamic_img_pad: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -407,11 +408,11 @@ class Eva(nn.Module):
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
-        self.dynamic_size = dynamic_size
+        self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
 
         embed_args = {}
-        if dynamic_size:
+        if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
@@ -419,6 +420,7 @@ class Eva(nn.Module):
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -442,7 +444,7 @@ class Eva(nn.Module):
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
+                feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -527,7 +529,7 @@ class Eva(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             if self.pos_embed is not None:
                 pos_embed = resample_abs_pos_embed(
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 8c9a9fa5..10b9296b 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -383,7 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
-    dynamic_size: Final[bool]
+    dynamic_img_size: Final[bool]
 
     def __init__(
             self,
@@ -401,9 +401,10 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
-            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
+            dynamic_img_size: bool = False,
+            dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
             pos_drop_rate: float = 0.,
             patch_drop_rate: float = 0.,
@@ -454,11 +455,11 @@ class VisionTransformer(nn.Module):
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
-        self.dynamic_size = dynamic_size
+        self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
 
         embed_args = {}
-        if dynamic_size:
+        if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = embed_layer(
@@ -467,6 +468,7 @@ class VisionTransformer(nn.Module):
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -554,7 +556,7 @@ class VisionTransformer(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x):
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index 7f04613e..e29bf73f 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -18,6 +18,7 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
@@ -32,6 +33,7 @@ class HybridEmbed(nn.Module):
     Extract feature map from CNN, flatten, project to embedding dim.
     """
     output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
 
     def __init__(
             self,
@@ -45,6 +47,7 @@ class HybridEmbed(nn.Module):
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -71,7 +74,8 @@ class HybridEmbed(nn.Module):
                 feature_dim = self.backbone.feature_info.channels()[-1]
             else:
                 feature_dim = self.backbone.num_features
-        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+        if not dynamic_img_pad:
+            assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
         self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         if output_fmt is not None:
@@ -82,6 +86,7 @@ class HybridEmbed(nn.Module):
             self.flatten = flatten
             self.output_fmt = Format.NCHW
         self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
 
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
 
@@ -89,6 +94,11 @@ class HybridEmbed(nn.Module):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
+        _, _, H, W = x.shape
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
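
Usage sketch (illustrative, not taken from the patch): a minimal example of exercising the two
new options, assuming `timm.create_model` forwards extra kwargs to the model constructor as it
does for `vit_*` entrypoints; the model name and input size below are arbitrary choices.

    import torch
    import timm

    # dynamic_img_size=True: position embedding is resampled to the input's patch grid at runtime.
    # dynamic_img_pad=True: input is zero-padded on the right/bottom so H and W need not be
    # divisible by the patch size (e.g. 252x380 is padded to 256x384 for a 16x16 patch).
    model = timm.create_model(
        'vit_base_patch16_224',
        pretrained=False,
        dynamic_img_size=True,
        dynamic_img_pad=True,
    )
    out = model(torch.randn(1, 3, 252, 380))
    print(out.shape)  # torch.Size([1, 1000])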