From e9373b1b925b2546706d78d25294de596bad4bfe Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 18 May 2023 16:43:48 -0700
Subject: [PATCH] Cleanup before samvit merge. Resize abs posembed on the fly,
 undo some line-wraps, remove redundant unbind, fix HF hub weight load

---
 timm/layers/__init__.py               |   2 +-
 timm/layers/patch_embed.py            |  16 +++-
 timm/layers/pos_embed.py              |  21 +++++
 timm/models/vision_transformer_sam.py | 114 +++++++++++++-------------
 4 files changed, 93 insertions(+), 60 deletions(-)

diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index d55faccc..caec5e69 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -36,7 +36,7 @@ from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
-from .pos_embed import resample_abs_pos_embed
+from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
     build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py
index 6ca6a1a1..b9a23921 100644
--- a/timm/layers/patch_embed.py
+++ b/timm/layers/patch_embed.py
@@ -37,6 +37,7 @@ class PatchEmbed(nn.Module):
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             bias: bool = True,
+            strict_img_size: bool = True,
     ):
         super().__init__()
         self.patch_size = to_2tuple(patch_size)
@@ -56,6 +57,7 @@ class PatchEmbed(nn.Module):
             # flatten spatial dim and transpose to channels last, kept for bwd compat
             self.flatten = flatten
             self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size

         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -63,8 +65,18 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
         if self.img_size is not None:
-            _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
-            _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+            if self.strict_img_size:
+                _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
+                _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
+            else:
+                _assert(
+                    H % self.patch_size[0] == 0,
+                    f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
+                )
+                _assert(
+                    W % self.patch_size[1] == 0,
+                    f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
+                )

         x = self.proj(x)
         if self.flatten:
diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py
index c3afce76..ad10578a 100644
--- a/timm/layers/pos_embed.py
+++ b/timm/layers/pos_embed.py
@@ -52,3 +52,24 @@ def resample_abs_pos_embed(
         _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

     return posemb
+
+
+def resample_abs_pos_embed_nhwc(
+        posemb,
+        new_size: List[int],
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
+        return posemb
+
+    # do the interpolation
+    posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
+    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
+    posemb = posemb.permute(0, 2, 3, 1)
+
+    if not torch.jit.is_scripting() and verbose:
+        _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
+
+    return posemb
\ No newline at end of file
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 5c395d2c..c8a8c53f 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -19,7 +19,8 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, Format
+from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\
+    Format, resample_abs_pos_embed_nhwc
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -71,24 +72,21 @@ class Attention(nn.Module):
     def forward(self, x):
         B, H, W, _ = x.shape
-        # qkv with shape (3, B, nHead, H * W, C)
         qkv = self.qkv(x).reshape(
             B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        # q, k, v with shape (B * nHead, H * W, C)
-        q, k, v = qkv.unbind(0)
+        # qkv with shape (3, B, nHead, H * W, C)
         q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+        # q, k, v with shape (B * nHead, H * W, C)
         q, k = self.q_norm(q), self.k_norm(k)
-
-        attn = (q * self.scale) @ k.transpose(-2, -1)
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
         if self.use_rel_pos:
-            attn = add_decomposed_rel_pos(
-                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
-        x = (attn @ v).view(B, self.num_heads, H, W, -
-                            1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
         x = self.proj(x)
         return x
@@ -136,13 +134,10 @@ class Block(nn.Module):
             proj_drop=proj_drop,
             norm_layer=norm_layer,
             use_rel_pos=use_rel_pos,
-            input_size=input_size if window_size == 0 else (
-                window_size, window_size),
+            input_size=input_size if window_size == 0 else (window_size, window_size),
         )
-        self.ls1 = LayerScale(
-            dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path1 = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp = mlp_layer(
@@ -151,10 +146,8 @@
             act_layer=act_layer,
             drop=proj_drop,
         )
-        self.ls2 = LayerScale(
-            dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path2 = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
         shortcut = x
@@ -194,10 +187,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
         x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
     Hp, Wp = H + pad_h, W + pad_w

-    x = x.view(B, Hp // window_size, window_size,
-               Wp // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
-    ).view(-1, window_size, window_size, C)
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
     return windows, (Hp, Wp)
@@ -218,8 +209,7 @@ def window_unpartition(
     Hp, Wp = pad_hw
     H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-    x = windows.view(B, Hp // window_size, Wp // window_size,
-                     window_size, window_size, -1)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

     if Hp > H or Wp > W:
@@ -248,16 +238,14 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
             size=max_rel_dist,
             mode="linear",
         )
-        rel_pos_resized = rel_pos_resized.reshape(
-            -1, max_rel_dist).permute(1, 0)
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
     else:
         rel_pos_resized = rel_pos

     # Scale the coords with short length if shapes for q and k are different.
     q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
     k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
-    relative_coords = (q_coords - k_coords) + \
-        (k_size - 1) * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

     return rel_pos_resized[relative_coords.long()]
@@ -331,7 +319,7 @@ class VisionTransformerSAM(nn.Module):
             drop_path_rate: float = 0.,
             weight_init: str = '',
             embed_layer: Callable = partial(
-                PatchEmbed, output_fmt=Format.NHWC),
+                PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
             norm_layer: Optional[Callable] = nn.LayerNorm,
             act_layer: Optional[Callable] = nn.GELU,
             block_fn: Callable = Block,
@@ -342,6 +330,7 @@
             global_attn_indexes: Tuple[int, ...] = (),
             neck_chans: int = 256,
             global_pool: str = 'avg',
+            head_hidden_size: Optional[int] = None
     ):
         """
         Args:
@@ -370,6 +359,7 @@
             window_size: Window size for window attention blocks. If 0, not use window attention.
             global_attn_indexes: Indexes for blocks using global attention. Used when window_size > 0.
             global_pool: Global pooling type.
+            head_hidden_size: If set, use NormMlpHead
         """
         super().__init__()
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
@@ -388,14 +378,12 @@
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used
         )
+        grid_size = self.patch_embed.grid_size
         if use_abs_pos:
             # Initialize absolute positional embedding with pretrain image size.
-            self.pos_embed = nn.Parameter(
-                torch.zeros(1, img_size // patch_size,
-                            img_size // patch_size, embed_dim)
-            )
+            self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim))
         else:
-            self.pos_embed = 0.
+            self.pos_embed = None
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropout(
@@ -424,7 +412,7 @@
                 mlp_layer=mlp_layer,
                 use_rel_pos=use_rel_pos,
                 window_size=window_size if i not in global_attn_indexes else 0,
-                input_size=(img_size // patch_size, img_size // patch_size),
+                input_size=grid_size,
             )
             for i in range(depth)])
@@ -451,12 +439,21 @@
             neck_chans = embed_dim

         # Classifier Head
-        self.head = ClassifierHead(
-            neck_chans,
-            num_classes,
-            pool_type=global_pool,
-            drop_rate=drop_rate,
-        )
+        if head_hidden_size:
+            self.head = NormMlpClassifierHead(
+                neck_chans,
+                num_classes,
+                hidden_size=head_hidden_size,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
+            )
+        else:
+            self.head = ClassifierHead(
+                neck_chans,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
+            )

     @torch.jit.ignore
     def no_weight_decay(self):
@@ -478,15 +475,14 @@
         return self.head

     def reset_classifier(self, num_classes=0, global_pool=None):
-        self.head = self.head.reset(num_classes, global_pool) if num_classes > 0 else nn.Identity()
-
-    def _pos_embed(self, x):
-        x = x + self.pos_embed
-        return self.pos_drop(x)
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.patch_embed(x)
-        x = self._pos_embed(x)
+        if self.pos_embed is not None:
+            # dynamically resize abs pos embedding if needed
+            x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
+        x = self.pos_drop(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -507,15 +503,19 @@
 def checkpoint_filter_fn(
         state_dict,
-        model
+        model,
 ):
     """ Remap SAM checkpoints -> timm """
+    sam_checkpoint = 'image_encoder.patch_embed.proj.weight' in state_dict
     out_dict = {}
     for k, v in state_dict.items():
-        if 'image_encoder.' in k:
-            new_k = k.replace('image_encoder.', '')
-            new_k = new_k.replace('mlp.lin', 'mlp.fc')
-            out_dict[new_k] = v
+        if k.startswith('image_encoder.'):
+            k = k[14:]
+            k = k.replace('mlp.lin', 'mlp.fc')
+        else:
+            if sam_checkpoint:
+                continue
+        out_dict[k] = v
     return out_dict
@@ -535,19 +535,19 @@ default_cfgs = generate_default_cfgs({
     # Segment-Anyhing Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
     'samvit_base_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
-        # hf_hub_id='timm/',
+        hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
     'samvit_large_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
-        # hf_hub_id='timm/',
+        hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
     'samvit_huge_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
-        # hf_hub_id='timm/',
+        hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),