From fee91fdd41f46ae9645c8356feea8752f29650c4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 14 Aug 2024 12:22:40 -0700 Subject: [PATCH 01/17] Update Hiera model for abswin, more stable weight init, layer-scale. ImageNet-12k weights for hiera_small_abswin, and two of the sbb vits with improved reg4 init. --- timm/models/hiera.py | 100 +++++++++++++++++++++++++----- timm/models/vision_transformer.py | 8 +++ 2 files changed, 91 insertions(+), 17 deletions(-) diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 200dd3e0..69af2f48 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -40,6 +40,7 @@ from ._registry import generate_default_cfgs, register_model from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function +from ._manipulate import named_apply __all__ = ['Hiera'] @@ -309,6 +310,21 @@ class MaskUnitAttention(nn.Module): return x +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + class HieraBlock(nn.Module): def __init__( self, @@ -317,6 +333,7 @@ class HieraBlock(nn.Module): heads: int, mlp_ratio: float = 4.0, drop_path: float = 0.0, + init_values: Optional[float] = None, norm_layer: nn.Module = nn.LayerNorm, act_layer: nn.Module = nn.GELU, q_stride: int = 1, @@ -348,13 +365,14 @@ class HieraBlock(nn.Module): window_size, use_mask_unit_attn ) + self.ls1 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity() self.norm2 = norm_layer(dim_out) self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) + self.ls2 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity() - def forward(self, x: torch.Tensor) -> torch.Tensor: # Attention + Q Pooling x_norm = self.norm1(x) @@ -369,10 +387,10 @@ class HieraBlock(nn.Module): ], dim=-1, ) - x = x + self.drop_path1(self.attn(x_norm)) + x = x + self.drop_path1(self.ls1(self.attn(x_norm))) # MLP - x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x @@ -470,6 +488,7 @@ class Hiera(nn.Module): mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1) # mask_unit_attn: which stages use mask unit attention? mask_unit_attn: Tuple[bool, ...] = (True, True, False, False), + use_expand_proj: bool = True, dim_mul: float = 2.0, head_mul: float = 2.0, patch_kernel: Tuple[int, ...] = (7, 7), @@ -477,6 +496,9 @@ class Hiera(nn.Module): patch_padding: Tuple[int, ...] = (3, 3), mlp_ratio: float = 4.0, drop_path_rate: float = 0.0, + init_values: Optional[float] = None, + fix_init: bool = True, + weight_init: str = '', norm_layer: Union[str, nn.Module] = "LayerNorm", drop_rate: float = 0.0, patch_drop_rate: float = 0.0, @@ -575,9 +597,11 @@ class Hiera(nn.Module): heads=num_heads, mlp_ratio=mlp_ratio, drop_path=dpr[i], + init_values=init_values, norm_layer=norm_layer, q_stride=(flat_q_stride if i in q_pool_blocks else 1), window_size=flat_mu_size, + use_expand_proj=use_expand_proj, use_mask_unit_attn=use_mask_unit_attn, ) embed_dim = dim_out @@ -605,19 +629,25 @@ class Hiera(nn.Module): elif self.pos_embed_abs is not None: nn.init.trunc_normal_(self.pos_embed_abs, std=0.02) nn.init.trunc_normal_(self.pos_embed_win, std=0.02) - self.apply(partial(self._init_weights)) + + if weight_init != 'skip': + if weight_init == 'jax': + named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self) + else: + named_apply(_init_weight_vit, self) + if fix_init: + self.fix_init_weight() if isinstance(self.head.fc, nn.Linear): self.head.fc.weight.data.mul_(head_init_scale) self.head.fc.bias.data.mul_(head_init_scale) - def _init_weights(self, m, init_bias=0.02): - if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): - nn.init.trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, init_bias) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, init_bias) - nn.init.constant_(m.weight, 1.0) + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) @torch.jit.ignore def no_weight_decay(self): @@ -829,6 +859,35 @@ class Hiera(nn.Module): return x +def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.): + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): + if name.startswith('head.fc'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.trunc_normal_(module.weight, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + nn.init.constant_(module.bias, init_bias) + + +def _init_weight_jax(module, name, head_bias=0.): + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + from timm.models.layers import lecun_normal_ + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + def _cfg(url='', **kwargs): return { 'url': url, @@ -901,8 +960,9 @@ default_cfgs = generate_default_cfgs({ num_classes=0, ), - "hiera_small_abswin_256.untrained": _cfg( - #hf_hub_id='timm/', + "hiera_small_abswin_256.sbb2_ep200_in12k": _cfg( + hf_hub_id='timm/', + num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95, ), "hiera_base_abswin_256.untrained": _cfg( @@ -985,11 +1045,17 @@ def hiera_huge_224(pretrained=False, **kwargs): @register_model def hiera_small_abswin_256(pretrained=False, **kwargs): - model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16)) + model_args = dict( + embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16), + init_values=1e-5, weight_init='jax', use_expand_proj=False, + ) return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def hiera_base_abswin_256(pretrained=False, **kwargs): - model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16)) - return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file + model_args = dict( + embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16), + init_values=1e-5, weight_init='jax', + ) + return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index b613683f..64cab9ee 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1967,6 +1967,10 @@ default_cfgs = { 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg( hf_hub_id='timm/', num_classes=11821, @@ -1980,6 +1984,10 @@ default_cfgs = { 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg( hf_hub_id='timm/', num_classes=11821, From 2f3fed43b8888b07585242bf591a1d6cb9118b7c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 15 Aug 2024 11:14:38 -0700 Subject: [PATCH 02/17] Fix hiera init with num_classes=0, fix weight tag names for sbb2 hiera/vit weights, add LayerScale/LayerScale2d to layers --- timm/layers/__init__.py | 1 + timm/layers/layer_scale.py | 38 +++++++++++++++++++++++++++++++ timm/models/hiera.py | 34 +++++++-------------------- timm/models/vision_transformer.py | 4 ++-- 4 files changed, 49 insertions(+), 28 deletions(-) create mode 100644 timm/layers/layer_scale.py diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 38c82407..61115589 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -29,6 +29,7 @@ from .grid import ndgrid, meshgrid from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple from .hybrid_embed import HybridEmbed, HybridEmbedWithSize from .inplace_abn import InplaceAbn +from .layer_scale import LayerScale, LayerScale2d from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp diff --git a/timm/layers/layer_scale.py b/timm/layers/layer_scale.py new file mode 100644 index 00000000..08566b2b --- /dev/null +++ b/timm/layers/layer_scale.py @@ -0,0 +1,38 @@ +import torch +from torch import nn + + +class LayerScale(nn.Module): + """ LayerScale on tensors with channels in last-dim. + """ + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class LayerScale2d(nn.Module): + """ LayerScale for tensors with torch 2D NCHW layout. + """ + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 69af2f48..808053e9 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -31,10 +31,8 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple - +from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple from ._registry import generate_default_cfgs, register_model from ._builder import build_model_with_cfg @@ -289,7 +287,6 @@ class MaskUnitAttention(nn.Module): """ Input should be of shape [batch, tokens, channels]. """ B, N, _ = x.shape num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1 - qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5) q, k, v = qkv.unbind(0) @@ -310,21 +307,6 @@ class MaskUnitAttention(nn.Module): return x -class LayerScale(nn.Module): - def __init__( - self, - dim: int, - init_values: float = 1e-5, - inplace: bool = False, - ) -> None: - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - class HieraBlock(nn.Module): def __init__( self, @@ -342,7 +324,6 @@ class HieraBlock(nn.Module): use_mask_unit_attn: bool = False, ): super().__init__() - self.dim = dim self.dim_out = dim_out @@ -631,10 +612,8 @@ class Hiera(nn.Module): nn.init.trunc_normal_(self.pos_embed_win, std=0.02) if weight_init != 'skip': - if weight_init == 'jax': - named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self) - else: - named_apply(_init_weight_vit, self) + init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit + named_apply(init_fn, self) if fix_init: self.fix_init_weight() if isinstance(self.head.fc, nn.Linear): @@ -868,11 +847,13 @@ def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.): nn.init.trunc_normal_(module.weight, std=0.02) if isinstance(module, nn.Linear) and module.bias is not None: nn.init.constant_(module.bias, init_bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def _init_weight_jax(module, name, head_bias=0.): if isinstance(module, nn.Linear): - if name.startswith('head'): + if name.startswith('head.fc'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: @@ -960,7 +941,7 @@ default_cfgs = generate_default_cfgs({ num_classes=0, ), - "hiera_small_abswin_256.sbb2_ep200_in12k": _cfg( + "hiera_small_abswin_256.sbb2_e200_in12k": _cfg( hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95, @@ -1007,6 +988,7 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera: **kwargs, ) + @register_model def hiera_tiny_224(pretrained=False, **kwargs): model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2)) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 64cab9ee..afb5e002 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1967,7 +1967,7 @@ default_cfgs = { 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg( + 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg( hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95), @@ -1984,7 +1984,7 @@ default_cfgs = { 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg( + 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg( hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95), From a50e53d41fa937c0baaa6e9318fde252e3e8b389 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 15 Aug 2024 17:46:36 -0700 Subject: [PATCH 03/17] Rename global pos embed for Hiera abswin, factor out commonly used vit weight init fns to layers. Add a channels-last ver of normmlp head. --- timm/layers/__init__.py | 5 ++- timm/layers/classifier.py | 76 +++++++++++++++++++++++++++++++++++++- timm/layers/weight_init.py | 44 +++++++++++++++++++++- timm/models/hiera.py | 72 ++++++++++-------------------------- 4 files changed, 141 insertions(+), 56 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 61115589..49ffa0ce 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -5,7 +5,7 @@ from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttention from .attention_pool import AttentionPoolLatent from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from .blur_pool import BlurPool2d, create_aa -from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead +from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \ set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn @@ -57,4 +57,5 @@ from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2d from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .trace_utils import _assert, _float_to_int from .typing import LayerType, PadType -from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \ + init_weight_jax, init_weight_vit diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 2441c050..1cbab683 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -134,7 +134,8 @@ class ClassifierHead(nn.Module): class NormMlpClassifierHead(nn.Module): - + """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors + """ def __init__( self, in_features: int, @@ -204,3 +205,76 @@ class NormMlpClassifierHead(nn.Module): return x x = self.fc(x) return x + + +class ClNormMlpClassifierHead(nn.Module): + """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors + """ + def __init__( + self, + in_features: int, + num_classes: int, + hidden_size: Optional[int] = None, + pool_type: str = 'avg', + drop_rate: float = 0., + norm_layer: Union[str, Callable] = 'layernorm', + act_layer: Union[str, Callable] = 'gelu', + input_fmt: str = 'NHWC', + ): + """ + Args: + in_features: The number of input features. + num_classes: The number of classes for the final classifier layer (output). + hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None. + pool_type: Global pooling type, pooling disabled if empty string (''). + drop_rate: Pre-classifier dropout rate. + norm_layer: Normalization layer type. + act_layer: MLP activation layer type (only used if hidden_size is not None). + """ + super().__init__() + self.in_features = in_features + self.hidden_size = hidden_size + self.num_features = in_features + assert pool_type in ('', 'avg', 'max', 'avgmax') + self.pool_type = pool_type + assert input_fmt in ('NHWC', 'NLC') + self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2) + norm_layer = get_norm_layer(norm_layer) + act_layer = get_act_layer(act_layer) + + self.norm = norm_layer(in_features) + if hidden_size: + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(in_features, hidden_size)), + ('act', act_layer()), + ])) + self.num_features = hidden_size + else: + self.pre_logits = nn.Identity() + self.drop = nn.Dropout(drop_rate) + self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def reset(self, num_classes: int, pool_type: Optional[str] = None): + if pool_type is not None: + self.pool_type = pool_type + self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _global_pool(self, x): + if self.pool_type: + if self.pool_type == 'avg': + x = x.mean(dim=self.pool_dim) + elif self.pool_type == 'max': + x = x.amax(dim=self.pool_dim) + elif self.pool_type == 'avgmax': + x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim)) + return x + + def forward(self, x, pre_logits: bool = False): + x = self._global_pool(x) + x = self.norm(x) + x = self.pre_logits(x) + x = self.drop(x) + if pre_logits: + return x + x = self.fc(x) + return x diff --git a/timm/layers/weight_init.py b/timm/layers/weight_init.py index 943e4f4c..d1127ecb 100644 --- a/timm/layers/weight_init.py +++ b/timm/layers/weight_init.py @@ -1,7 +1,7 @@ import torch import math import warnings - +from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out @@ -123,3 +123,45 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): def lecun_normal_(tensor): variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') + + +def init_weight_vit( + module: nn.Module, + name: str, + init_bias: float = 0.02, + head_bias: float = 0., + classifier_name: str = 'head' +): + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): + if name.startswith(classifier_name): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.trunc_normal_(module.weight, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + nn.init.constant_(module.bias, init_bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weight_jax( + module: nn.Module, + name: str, + head_bias: float = 0., + classifier_name: str = 'head', +): + if isinstance(module, nn.Linear): + if name.startswith(classifier_name): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 808053e9..ec5d8b7b 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -32,7 +32,8 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple +from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \ + init_weight_vit, init_weight_jax from ._registry import generate_default_cfgs, register_model from ._builder import build_model_with_cfg @@ -391,7 +392,7 @@ class NormClassifierHead(nn.Module): self.pool_type = pool_type self.norm = norm_layer(in_features) self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity() - self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity() def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False): if pool_type is not None: @@ -400,7 +401,7 @@ class NormClassifierHead(nn.Module): if other: # reset other non-fc layers self.norm = nn.Identity() - self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: if self.pool_type == 'avg': @@ -486,7 +487,7 @@ class Hiera(nn.Module): head_init_scale: float = 0.001, sep_pos_embed: bool = False, abs_win_pos_embed: bool = False, - abs_pos_size: Tuple[int, int] = (14, 14), + global_pos_size: Tuple[int, int] = (14, 14), ): super().__init__() self.num_classes = num_classes @@ -513,11 +514,9 @@ class Hiera(nn.Module): patch_kernel, patch_stride, patch_padding, - #reshape=False, # leave spatial / temporal dims in output ) self.pos_embed: Optional[nn.Parameter] = None - self.pos_embed_abs: Optional[nn.Parameter] = None self.pos_embed_win: Optional[nn.Parameter] = None self.pos_embed_spatial: Optional[nn.Parameter] = None self.pos_embed_temporal: Optional[nn.Parameter] = None @@ -531,7 +530,7 @@ class Hiera(nn.Module): else: if abs_win_pos_embed: # absolute win, params NCHW to make tile & interpolate more natural before add & reshape - self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size)) + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size)) self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size)) else: self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim)) @@ -555,7 +554,7 @@ class Hiera(nn.Module): # Transformer blocks cur_stage = 0 depth = sum(stages) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList() self.feature_info = [] for i in range(depth): @@ -607,12 +606,12 @@ class Hiera(nn.Module): else: if self.pos_embed is not None: nn.init.trunc_normal_(self.pos_embed, std=0.02) - elif self.pos_embed_abs is not None: - nn.init.trunc_normal_(self.pos_embed_abs, std=0.02) + if self.pos_embed_win is not None: nn.init.trunc_normal_(self.pos_embed_win, std=0.02) if weight_init != 'skip': - init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit + init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit + init_fn = partial(init_fn, classifier_name='head.fc') named_apply(init_fn, self) if fix_init: self.fix_init_weight() @@ -681,20 +680,20 @@ class Hiera(nn.Module): return mask.bool() def _pos_embed(self, x) -> torch.Tensor: - if self.pos_embed is not None: - pos_embed = self.pos_embed - elif self.pos_embed_abs is not None: + if self.pos_embed_win is not None: # absolute win position embedding, from # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613) pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape) - pos_embed_abs = F.interpolate( - self.pos_embed_abs, + pos_embed = F.interpolate( + self.pos_embed, size=pos_embed_win.shape[-2:], mode='bicubic', antialias=True, ) - pos_embed = pos_embed_abs + pos_embed_win + pos_embed = pos_embed + pos_embed_win pos_embed = pos_embed.flatten(2).transpose(1, 2) + elif self.pos_embed is not None: + pos_embed = self.pos_embed else: pos_embed = ( self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1) @@ -838,37 +837,6 @@ class Hiera(nn.Module): return x -def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.): - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): - if name.startswith('head.fc'): - nn.init.zeros_(module.weight) - nn.init.constant_(module.bias, head_bias) - else: - nn.init.trunc_normal_(module.weight, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - nn.init.constant_(module.bias, init_bias) - elif hasattr(module, 'init_weights'): - module.init_weights() - - -def _init_weight_jax(module, name, head_bias=0.): - if isinstance(module, nn.Linear): - if name.startswith('head.fc'): - nn.init.zeros_(module.weight) - nn.init.constant_(module.bias, head_bias) - else: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) - elif isinstance(module, nn.Conv2d): - from timm.models.layers import lecun_normal_ - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, 'init_weights'): - module.init_weights() - - def _cfg(url='', **kwargs): return { 'url': url, @@ -972,6 +940,8 @@ def checkpoint_filter_fn(state_dict, model=None): k = k.replace('encoder_norm.', 'head.norm.') elif k.startswith('norm.'): k = k.replace('norm.', 'head.norm.') + if k == 'pos_embed_abs': + k = 'pos_embed' output[k] = v return output @@ -1028,7 +998,7 @@ def hiera_huge_224(pretrained=False, **kwargs): @register_model def hiera_small_abswin_256(pretrained=False, **kwargs): model_args = dict( - embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16), + embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, global_pos_size=(16, 16), init_values=1e-5, weight_init='jax', use_expand_proj=False, ) return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1037,7 +1007,5 @@ def hiera_small_abswin_256(pretrained=False, **kwargs): @register_model def hiera_base_abswin_256(pretrained=False, **kwargs): model_args = dict( - embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16), - init_values=1e-5, weight_init='jax', - ) + embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, init_values=1e-5, weight_init='jax') return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs)) From f2cfb4c677f191521bbd4d46d9aae591d0dc4b28 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 15 Aug 2024 17:58:15 -0700 Subject: [PATCH 04/17] Add WIP HieraDet impl (SAM2 backbone support) --- timm/layers/create_act.py | 2 + timm/models/__init__.py | 1 + timm/models/hieradet_sam2.py | 564 +++++++++++++++++++++++++++++++++++ 3 files changed, 567 insertions(+) create mode 100644 timm/models/hieradet_sam2.py diff --git a/timm/layers/create_act.py b/timm/layers/create_act.py index 6bbbc14b..c734785d 100644 --- a/timm/layers/create_act.py +++ b/timm/layers/create_act.py @@ -97,6 +97,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'): return None if isinstance(name, Callable): return name + name = name.lower() if not (is_exportable() or is_scriptable()): # If not exporting or scripting the model, first look for a memory-efficient version with # custom autograd, then fallback @@ -117,6 +118,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): return name if not name: return None + name = name.lower() if not (is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 7c926319..5e723724 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -27,6 +27,7 @@ from .ghostnet import * from .hardcorenas import * from .hgnet import * from .hiera import * +from .hieradet_sam2 import * from .hrnet import * from .inception_next import * from .inception_resnet_v2 import * diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py new file mode 100644 index 00000000..2ba167ae --- /dev/null +++ b/timm/models/hieradet_sam2.py @@ -0,0 +1,564 @@ +import math +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.jit import Final + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \ + get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit + +from ._builder import build_model_with_cfg +from ._features import feature_take_indices +from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + + +def do_pool( + x: torch.Tensor, + pool: nn.Module, + norm: nn.Module = None, +) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: List[int], hw: List[int]): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + x = x[:, :H, :W, :].contiguous() + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2).reshape(B, H, W, -1) + + x = self.proj(x) + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + q_stride: Tuple[int, int] = None, + norm_layer: Union[nn.Module, str] = "LayerNorm", + act_layer: Union[nn.Module, str] = "GELU", + window_size: int = 0, + ): + super().__init__() + norm_layer = get_norm_layer(norm_layer) + act_layer = get_act_layer(act_layer) + self.window_size = window_size + self.dim = dim + self.dim_out = dim_out + + self.norm1 = norm_layer(dim) + self.pool = None + self.q_stride = q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, + stride=q_stride, + ceil_mode=False, + ) + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp( + dim_out, + int(dim_out * mlp_ratio), + act_layer=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + feat_size = x.shape[1:3] + pad_hw = 0, 0 + if window_size > 0: + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + feat_size = shortcut.shape[1:3] + + pad_h = (window_size - feat_size[0] % window_size) % window_size + pad_w = (window_size - feat_size[1] % window_size) % window_size + pad_hw = (feat_size[0] + pad_h, feat_size[1] + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, feat_size) + + x = shortcut + self.drop_path(x) + + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class HieraPatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +class HieraDet(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + patch_kernel: Tuple[int, ...] = (7, 7), + patch_stride: Tuple[int, ...] = (4, 4), + patch_padding: Tuple[int, ...] = (3, 3), + patch_size: Optional[Tuple[int, ...]] = None, + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + global_pos_size: Tuple[int, int] = (7, 7), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weight_init: str = '', + fix_init: bool = True, + head_init_scale: float = 0.001, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, # stochastic depth + norm_layer: Union[nn.Module, str] = "LayerNorm", + act_layer: Union[nn.Module, str] = "GELU", + ): + super().__init__() + norm_layer = get_norm_layer(norm_layer) + act_layer = get_act_layer(act_layer) + assert len(stages) == len(window_spec) + self.num_classes = num_classes + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + + if patch_size is not None: + # use a non-overlapping vit style patch embed + self.patch_embed = PatchEmbed( + img_size=None, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + output_fmt='NHWC', + dynamic_img_pad=True, + ) + else: + self.patch_embed = HieraPatchEmbed( + kernel_size=patch_kernel, + stride=patch_stride, + padding=patch_padding, + in_chans=in_chans, + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.global_pos_size = global_pos_size + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size)) + self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.Sequential() + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + norm_layer=norm_layer, + act_layer=act_layer, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if True else + [self.blocks[-1].dim_out] + ) + + self.num_features = self.head_hidden_size = embed_dim + self.head = ClNormMlpClassifierHead( + embed_dim, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + norm_layer=norm_layer, + ) + + # Initialize everything + if self.pos_embed is not None: + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + if self.pos_embed_window is not None: + nn.init.trunc_normal_(self.pos_embed_window, std=0.02) + + if weight_init != 'skip': + init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit + init_fn = partial(init_fn, classifier_name='head.fc') + named_apply(init_fn, self) + + if fix_init: + self.fix_init_weight() + + if isinstance(self.head, ClNormMlpClassifierHead) and isinstance(self.head.fc, nn.Linear): + self.head.fc.weight.data.mul_(head_init_scale) + self.head.fc.bias.data.mul_(head_init_scale) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + tile_h = pos_embed.shape[-2] // window_embed.shape[-2] + tile_w = pos_embed.shape[-1] // window_embed.shape[-1] + pos_embed = pos_embed + window_embed.tile((tile_h, tile_w)) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + @torch.jit.ignore + def no_weight_decay(self): + return ['pos_embed', 'pos_embed_win'] + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed', + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) # BHWC + x = x + self._get_pos_embed(x.shape[1:3]) + for i, blk in enumerate(self.blocks): + x = blk(x) + return x + + def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor: + x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + return x + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +# NOTE sam2 appears to use 1024x1024 for all models, but T, S, & B+ have windows that fit multiples of 224. +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 0, 'input_size': (3, 896, 896), 'pool_size': (28, 28), + 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + "sam2_hiera_tiny.r224": _cfg( + hf_hub_id='facebook/sam2-hiera-tiny', + hf_hub_filename='sam2_hiera_tiny.pt', + input_size=(3, 224, 224), pool_size=(7, 7), + ), # FIXME reduced res for testing + "sam2_hiera_tiny.r896": _cfg( + hf_hub_id='facebook/sam2-hiera-tiny', + hf_hub_filename='sam2_hiera_tiny.pt', + ), + "sam2_hiera_small": _cfg( + hf_hub_id='facebook/sam2-hiera-small', + hf_hub_filename='sam2_hiera_small.pt', + ), + "sam2_hiera_base_plus": _cfg( + hf_hub_id='facebook/sam2-hiera-base-plus', + hf_hub_filename='sam2_hiera_base_plus.pt', + ), + "sam2_hiera_large": _cfg( + hf_hub_id='facebook/sam2-hiera-large', + hf_hub_filename='sam2_hiera_large.pt', + input_size=(3, 1024, 1024), pool_size=(32, 32), + ), +}) + + +def checkpoint_filter_fn(state_dict, model=None, prefix=''): + state_dict = state_dict.get('model', state_dict) + + output = {} + for k, v in state_dict.items(): + if k.startswith(prefix): + k = k.replace(prefix, '') + else: + continue + k = k.replace('mlp.layers.0', 'mlp.fc1') + k = k.replace('mlp.layers.1', 'mlp.fc2') + output[k] = v + return output + + +def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet: + out_indices = kwargs.pop('out_indices', 4) + if True: # kwargs.get('pretrained_cfg', '') == '?': + # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`) + # This is workaround loading with num_classes=0 w/o removing norm-layer. + kwargs.setdefault('pretrained_strict', False) + checkpoint_prefix = 'image_encoder.trunk.' if 'sam2' in variant else '' + return build_model_with_cfg( + HieraDet, + variant, + pretrained, + pretrained_filter_fn=partial(checkpoint_filter_fn, prefix=checkpoint_prefix), + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + + +@register_model +def sam2_hiera_tiny(pretrained=False, **kwargs): + model_args = dict(stages=(1, 2, 7, 2), global_att_blocks=(5, 7, 9)) + return _create_hiera_det('sam2_hiera_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def sam2_hiera_small(pretrained=False, **kwargs): + model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13)) + return _create_hiera_det('sam2_hiera_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def sam2_hiera_base(pretrained=False, **kwargs): +# model_args = dict() +# return _create_hiera_det('sam2_hiera_base', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def sam2_hiera_base_plus(pretrained=False, **kwargs): + model_args = dict(embed_dim=112, num_heads=2, global_pos_size=(14, 14)) + return _create_hiera_det('sam2_hiera_base_plus', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def sam2_hiera_large(pretrained=False, **kwargs): + model_args = dict( + embed_dim=144, + num_heads=2, + stages=(2, 6, 36, 4), + global_att_blocks=(23, 33, 43), + window_spec=(8, 4, 16, 8), + ) + return _create_hiera_det('sam2_hiera_large', pretrained=pretrained, **dict(model_args, **kwargs)) From 962958723ca443644c8deb70e9599d96d4540172 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 16 Aug 2024 11:10:04 -0700 Subject: [PATCH 05/17] More Hiera updates. Add forward_intermediates to hieradat/sam2 impl. Make both use same classifier module. Add coarse bool to intermediates. --- timm/layers/classifier.py | 5 +- timm/models/hiera.py | 74 ++++++++---------------- timm/models/hieradet_sam2.py | 106 ++++++++++++++++++++++++++++------- 3 files changed, 114 insertions(+), 71 deletions(-) diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 1cbab683..5e425fe6 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -254,9 +254,12 @@ class ClNormMlpClassifierHead(nn.Module): self.drop = nn.Dropout(drop_rate) self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def reset(self, num_classes: int, pool_type: Optional[str] = None): + def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False): if pool_type is not None: self.pool_type = pool_type + if reset_other: + self.pre_logits = nn.Identity() + self.norm = nn.Identity() self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def _global_pool(self, x): diff --git a/timm/models/hiera.py b/timm/models/hiera.py index ec5d8b7b..78d32752 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -32,8 +32,8 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \ - init_weight_vit, init_weight_jax +from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \ + _assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax from ._registry import generate_default_cfgs, register_model from ._builder import build_model_with_cfg @@ -376,44 +376,6 @@ class HieraBlock(nn.Module): return x -class NormClassifierHead(nn.Module): - def __init__( - self, - in_features: int, - num_classes: int, - pool_type: str = 'avg', - drop_rate: float = 0.0, - norm_layer: Union[str, Callable] = 'layernorm', - ): - super().__init__() - norm_layer = get_norm_layer(norm_layer) - assert pool_type in ('avg', '') - self.in_features = self.num_features = in_features - self.pool_type = pool_type - self.norm = norm_layer(in_features) - self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity() - self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity() - - def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False): - if pool_type is not None: - assert pool_type in ('avg', '') - self.pool_type = pool_type - if other: - # reset other non-fc layers - self.norm = nn.Identity() - self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() - - def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: - if self.pool_type == 'avg': - x = x.mean(dim=1) - x = self.norm(x) - x = self.drop(x) - if pre_logits: - return x - x = self.fc(x) - return x - - class PatchEmbed(nn.Module): """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d).""" @@ -591,12 +553,13 @@ class Hiera(nn.Module): self.blocks.append(block) self.num_features = self.head_hidden_size = embed_dim - self.head = NormClassifierHead( + self.head = ClNormMlpClassifierHead( embed_dim, num_classes, pool_type=global_pool, drop_rate=drop_rate, norm_layer=norm_layer, + input_fmt='NLC', ) # Initialize everything @@ -651,9 +614,9 @@ class Hiera(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False): self.num_classes = num_classes - self.head.reset(num_classes, global_pool, other=other) + self.head.reset(num_classes, global_pool, reset_other=reset_other) def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor: """ @@ -716,6 +679,7 @@ class Hiera(nn.Module): stop_early: bool = True, output_fmt: str = 'NCHW', intermediates_only: bool = False, + coarse: bool = True, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. @@ -730,10 +694,13 @@ class Hiera(nn.Module): """ assert not norm, 'normalization of features not supported' - assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.' - take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) - take_indices = [self.stage_ends[i] for i in take_indices] - max_index = self.stage_ends[max_index] + assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.' + if coarse: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + else: + take_indices, max_index = feature_take_indices(len(self.blocks), indices) if mask is not None: patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape @@ -755,7 +722,8 @@ class Hiera(nn.Module): for i, blk in enumerate(blocks): x = blk(x) if i in take_indices: - intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2)) + x_int = self.reroll(x, i, mask=mask) + intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int) if intermediates_only: return intermediates @@ -767,14 +735,18 @@ class Hiera(nn.Module): indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, + coarse: bool = True, ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) - max_index = self.stage_ends[max_index] + if coarse: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + else: + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_head: - self.head.reset(0, other=True) + self.head.reset(0, reset_other=True) return take_indices def forward_features( diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 2ba167ae..d5a78679 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -328,18 +328,16 @@ class HieraDet(nn.Module): self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size)) self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, depth) - ] # stochastic depth decay rule - - cur_stage = 1 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + cur_stage = 0 self.blocks = nn.Sequential() + self.feature_info = [] for i in range(depth): dim_out = embed_dim # lags by a block, so first block of # next stage uses an initial window size # of previous stage and final window size of current stage - window_size = self.window_spec[cur_stage - 1] + window_size = self.window_spec[cur_stage] if self.global_att_blocks is not None: window_size = 0 if i in self.global_att_blocks else window_size @@ -362,6 +360,9 @@ class HieraDet(nn.Module): embed_dim = dim_out self.blocks.append(block) + if i in self.stage_ends: + self.feature_info += [ + dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')] self.channel_list = ( [self.blocks[i].dim_out for i in self.stage_ends[::-1]] @@ -397,15 +398,15 @@ class HieraDet(nn.Module): self.head.fc.weight.data.mul_(head_init_scale) self.head.fc.bias.data.mul_(head_init_scale) - def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: - h, w = hw + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + h, w = x.shape[1:3] window_embed = self.pos_embed_window pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") tile_h = pos_embed.shape[-2] // window_embed.shape[-2] tile_w = pos_embed.shape[-1] // window_embed.shape[-1] pos_embed = pos_embed + window_embed.tile((tile_h, tile_w)) pos_embed = pos_embed.permute(0, 2, 3, 1) - return pos_embed + return x + pos_embed def fix_init_weight(self): def rescale(param, _layer_id): @@ -417,13 +418,13 @@ class HieraDet(nn.Module): @torch.jit.ignore def no_weight_decay(self): - return ['pos_embed', 'pos_embed_win'] + return ['pos_embed', 'pos_embed_window'] @torch.jit.ignore def group_matcher(self, coarse: bool = False) -> Dict: return dict( - stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed', - blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + stem=r'^pos_embed|pos_embed_window|patch_embed', + blocks=[(r'^blocks\.(\d+)', None)] ) @torch.jit.ignore @@ -434,13 +435,83 @@ class HieraDet(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False): self.num_classes = num_classes - self.head.reset(num_classes, pool_type=global_pool) + self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other) + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + coarse: bool = True, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + coarse: Take coarse features (stage ends) if true, otherwise all block featrures + Returns: + + """ + assert not norm, 'normalization of features not supported' + assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.' + if coarse: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + else: + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + x = self.patch_embed(x) + x = self._pos_embed(x) + + intermediates = [] + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x + intermediates.append(x_out) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + coarse: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + if coarse: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + else: + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_head: + self.head.reset(0, reset_other=True) + return take_indices def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]: x = self.patch_embed(x) # BHWC - x = x + self._get_pos_embed(x.shape[1:3]) + x = self._pos_embed(x) for i, blk in enumerate(self.blocks): x = blk(x) return x @@ -449,10 +520,7 @@ class HieraDet(nn.Module): x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) return x - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) return x From 146c2fbe3481248f044d1545aa765b032f35850d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 16 Aug 2024 12:10:00 -0700 Subject: [PATCH 06/17] Add resnet50d and efficientnet_b1 ra4 (mnv4) hparam weights --- timm/models/efficientnet.py | 8 ++++++-- timm/models/resnet.py | 5 +++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 3aee0342..09d6c66c 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -1290,8 +1290,12 @@ default_cfgs = generate_default_cfgs({ 'efficientnet_b0.ra4_e3600_r224_in1k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0 - ), + crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0), + 'efficientnet_b1.ra4_e3600_r240_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), crop_pct=0.9, + test_input_size=(3, 288, 288), test_crop_pct=1.0), 'efficientnet_b1.ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', hf_hub_id='timm/', diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 1d60deca..a80954cc 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -783,6 +783,11 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', first_conv='conv1.0'), + 'resnet50d.ra4_e3600_r224_in1k': _rcfg( + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0, + first_conv='conv1.0'), 'resnet50d.a1_in1k': _rcfg( hf_hub_id='timm/', url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth', From e03538117183e50e052ad5c2e401900ecc560869 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 16 Aug 2024 13:36:02 -0700 Subject: [PATCH 07/17] Move padding out of windowing code for hieradet, fix torchscript typing issues, make pooling MaxPool unique instances across two modules --- timm/models/hieradet_sam2.py | 126 ++++++++++++++++------------------- 1 file changed, 56 insertions(+), 70 deletions(-) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index d5a78679..4f81b8b2 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -9,7 +9,7 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \ - get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit + get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -17,25 +17,7 @@ from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations -def do_pool( - x: torch.Tensor, - pool: nn.Module, - norm: nn.Module = None, -) -> torch.Tensor: - if pool is None: - return x - # (B, H, W, C) -> (B, C, H, W) - x = x.permute(0, 3, 1, 2) - x = pool(x) - # (B, C, H', W') -> (B, H', W', C) - x = x.permute(0, 2, 3, 1) - if norm: - x = norm(x) - - return x - - -def window_partition(x, window_size): +def window_partition(x, window_size: Tuple[int, int]): """ Partition into non-overlapping windows with padding if needed. Args: @@ -46,41 +28,35 @@ def window_partition(x, window_size): (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape - - pad_h = (window_size - H % window_size) % window_size - pad_w = (window_size - W % window_size) % window_size - x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) - Hp, Wp = H + pad_h, W + pad_w - - x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) - return windows, (Hp, Wp) + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows -def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: List[int], hw: List[int]): +def window_unpartition(windows: torch.Tensor, window_size: Tuple[int, int], hw: Tuple[int, int]): """ Window unpartition into original sequences and removing padding. Args: x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. - pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ - Hp, Wp = pad_hw H, W = hw - B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view( - B, Hp // window_size, Wp // window_size, window_size, window_size, -1 - ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) - x = x[:, :H, :W, :].contiguous() + B = windows.shape[0] // (H * W // window_size[0] // window_size[1]) + x = windows.view(B, H // window_size[0], W // window_size[0], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x +def _calc_pad(H: int, W: int, window_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + pad_h = (window_size[0] - H % window_size[0]) % window_size[0] + pad_w = (window_size[1] - W % window_size[1]) % window_size[1] + Hp, Wp = H + pad_h, W + pad_w + return Hp, Wp, pad_h, pad_w + + class MultiScaleAttention(nn.Module): def __init__( self, @@ -112,8 +88,9 @@ class MultiScaleAttention(nn.Module): q, k, v = torch.unbind(qkv, 2) # Q pooling (for downsample at stage changes) - if self.q_pool: - q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + if self.q_pool is not None: + q = q.reshape(B, H, W, -1).permute(0, 3, 1, 2) # to BCHW for pool + q = self.q_pool(q).permute(0, 2, 3, 1) H, W = q.shape[1:3] # downsampled shape q = q.reshape(B, H * W, self.num_heads, -1) @@ -138,7 +115,7 @@ class MultiScaleBlock(nn.Module): num_heads: int, mlp_ratio: float = 4.0, drop_path: float = 0.0, - q_stride: Tuple[int, int] = None, + q_stride: Optional[Tuple[int, int]] = None, norm_layer: Union[nn.Module, str] = "LayerNorm", act_layer: Union[nn.Module, str] = "GELU", window_size: int = 0, @@ -146,24 +123,26 @@ class MultiScaleBlock(nn.Module): super().__init__() norm_layer = get_norm_layer(norm_layer) act_layer = get_act_layer(act_layer) - self.window_size = window_size + self.window_size = to_2tuple(window_size) + self.is_windowed = any(self.window_size) self.dim = dim self.dim_out = dim_out - - self.norm1 = norm_layer(dim) - self.pool = None self.q_stride = q_stride if self.q_stride: - self.pool = nn.MaxPool2d( + q_pool = nn.MaxPool2d( kernel_size=q_stride, stride=q_stride, ceil_mode=False, ) + else: + q_pool = None + + self.norm1 = norm_layer(dim) self.attn = MultiScaleAttention( dim, dim_out, num_heads=num_heads, - q_pool=self.pool, + q_pool=q_pool, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() @@ -176,6 +155,16 @@ class MultiScaleBlock(nn.Module): if dim != dim_out: self.proj = nn.Linear(dim, dim_out) + else: + self.proj = nn.Identity() + self.pool = None + if self.q_stride: + # note make a different instance for this Module so that it's not shared with attn module + self.pool = nn.MaxPool2d( + kernel_size=q_stride, + stride=q_stride, + ceil_mode=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x # B, H, W, C @@ -183,29 +172,32 @@ class MultiScaleBlock(nn.Module): # Skip connection if self.dim != self.dim_out: - shortcut = do_pool(self.proj(x), self.pool) + shortcut = self.proj(x) + if self.pool is not None: + shortcut = shortcut.permute(0, 3, 1, 2) + shortcut = self.pool(shortcut).permute(0, 2, 3, 1) # Window partition window_size = self.window_size - feat_size = x.shape[1:3] - pad_hw = 0, 0 - if window_size > 0: - x, pad_hw = window_partition(x, window_size) + H, W = x.shape[1:3] + Hp, Wp = H, W # keep torchscript happy + if self.is_windowed: + x = window_partition(x, window_size) + Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size) + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) # Window Attention + Q Pooling (if stage change) x = self.attn(x) - if self.q_stride: + if self.q_stride is not None: # Shapes have changed due to Q pooling - window_size = self.window_size // self.q_stride[0] - feat_size = shortcut.shape[1:3] - - pad_h = (window_size - feat_size[0] % window_size) % window_size - pad_w = (window_size - feat_size[1] % window_size) % window_size - pad_hw = (feat_size[0] + pad_h, feat_size[1] + pad_w) + window_size = (self.window_size[0] // self.q_stride[0], self.window_size[1] // self.q_stride[1]) + H, W = shortcut.shape[1:3] + Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size) # Reverse window partition - if self.window_size > 0: - x = window_unpartition(x, window_size, pad_hw, feat_size) + if self.is_windowed: + x = window_unpartition(x, window_size, (Hp, Wp)) + x = x[:, :H, :W, :].contiguous() # unpad x = shortcut + self.drop_path(x) @@ -364,12 +356,6 @@ class HieraDet(nn.Module): self.feature_info += [ dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')] - self.channel_list = ( - [self.blocks[i].dim_out for i in self.stage_ends[::-1]] - if True else - [self.blocks[-1].dim_out] - ) - self.num_features = self.head_hidden_size = embed_dim self.head = ClNormMlpClassifierHead( embed_dim, @@ -509,7 +495,7 @@ class HieraDet(nn.Module): self.head.reset(0, reset_other=True) return take_indices - def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]: + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) # BHWC x = self._pos_embed(x) for i, blk in enumerate(self.blocks): From 0b05122cdad4240c4eac7263a517a530d0987a6b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 16 Aug 2024 14:33:40 -0700 Subject: [PATCH 08/17] Fixing hieradet (sam2) tests --- tests/test_models.py | 4 ++-- timm/models/hieradet_sam2.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 15e6cc35..fd09ceb2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -52,14 +52,14 @@ FEAT_INTER_FILTERS = [ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', - 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', - 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', ] NUM_NON_STD = len(NON_STD_FILTERS) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 4f81b8b2..464e3313 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -45,7 +45,7 @@ def window_unpartition(windows: torch.Tensor, window_size: Tuple[int, int], hw: """ H, W = hw B = windows.shape[0] // (H * W // window_size[0] // window_size[1]) - x = windows.view(B, H // window_size[0], W // window_size[0], window_size[0], window_size[1], -1) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -567,11 +567,12 @@ def checkpoint_filter_fn(state_dict, model=None, prefix=''): def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet: out_indices = kwargs.pop('out_indices', 4) - if True: # kwargs.get('pretrained_cfg', '') == '?': + checkpoint_prefix = '' + if 'sam2' in variant: # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`) # This is workaround loading with num_classes=0 w/o removing norm-layer. kwargs.setdefault('pretrained_strict', False) - checkpoint_prefix = 'image_encoder.trunk.' if 'sam2' in variant else '' + checkpoint_prefix = 'image_encoder.trunk.' return build_model_with_cfg( HieraDet, variant, From de3a91a7a0f1098b515c0bbe24b178467a6f2d70 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 16 Aug 2024 15:13:56 -0700 Subject: [PATCH 09/17] Add min_input_size of 128 for hieradet/sam2 --- timm/models/hieradet_sam2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 464e3313..04cb1484 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -517,7 +517,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 0, 'input_size': (3, 896, 896), 'pool_size': (28, 28), - 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'min_input_size': (3, 128, 128), 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', **kwargs From 691bb54443944eddaf6c02938b4f5c77ba1fb6ee Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 16 Aug 2024 17:09:19 -0700 Subject: [PATCH 10/17] Larger min input size needed --- timm/models/hieradet_sam2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 04cb1484..2a40ba0c 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -517,7 +517,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 0, 'input_size': (3, 896, 896), 'pool_size': (28, 28), - 'crop_pct': 1.0, 'interpolation': 'bicubic', 'min_input_size': (3, 128, 128), + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'min_input_size': (3, 224, 224), 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', **kwargs @@ -545,6 +545,7 @@ default_cfgs = generate_default_cfgs({ "sam2_hiera_large": _cfg( hf_hub_id='facebook/sam2-hiera-large', hf_hub_filename='sam2_hiera_large.pt', + min_input_size=(3, 256, 256), input_size=(3, 1024, 1024), pool_size=(32, 32), ), }) From 1bd92bca0eb30b861937b5a02d41febed14480bf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 16 Aug 2024 22:57:49 -0700 Subject: [PATCH 11/17] Add fused_attn flag to HieraDet attn block --- timm/models/hieradet_sam2.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 2a40ba0c..652a59e0 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -9,7 +9,7 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \ - get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple + get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -58,6 +58,8 @@ def _calc_pad(H: int, W: int, window_size: Tuple[int, int]) -> Tuple[int, int, i class MultiScaleAttention(nn.Module): + fused_attn: torch.jit.Final[bool] + def __init__( self, dim: int, @@ -66,13 +68,12 @@ class MultiScaleAttention(nn.Module): q_pool: nn.Module = None, ): super().__init__() - self.dim = dim self.dim_out = dim_out - self.num_heads = num_heads head_dim = dim_out // num_heads - self.scale = head_dim**-0.5 + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.q_pool = q_pool self.qkv = nn.Linear(dim, dim_out * 3) @@ -95,11 +96,17 @@ class MultiScaleAttention(nn.Module): q = q.reshape(B, H * W, self.num_heads, -1) # Torch's SDPA expects [B, nheads, H*W, C] so we transpose - x = F.scaled_dot_product_attention( - q.transpose(1, 2), - k.transpose(1, 2), - v.transpose(1, 2), - ) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v) + else: + q = q * self.scale + attn = q @ k.transpose(-1, -2) + attn = attn.softmax(dim=-1) + x = attn @ v + # Transpose back x = x.transpose(1, 2).reshape(B, H, W, -1) From 7d83749207ee258a3e5fbec225b534dd0af1ccfd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 17 Aug 2024 08:27:13 -0700 Subject: [PATCH 12/17] pool size test fixes --- timm/models/efficientnet.py | 2 +- timm/models/hieradet_sam2.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 09d6c66c..2cf4130d 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -1294,7 +1294,7 @@ default_cfgs = generate_default_cfgs({ 'efficientnet_b1.ra4_e3600_r240_in1k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 240, 240), crop_pct=0.9, + input_size=(3, 240, 240), crop_pct=0.9, pool_size=(8, 8), test_input_size=(3, 288, 288), test_crop_pct=1.0), 'efficientnet_b1.ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 652a59e0..4d57f5c3 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -294,6 +294,7 @@ class HieraDet(nn.Module): assert len(stages) == len(window_spec) self.num_classes = num_classes self.window_spec = window_spec + self.output_fmt = 'NHWC' depth = sum(stages) self.q_stride = q_stride From a256e5045711d32d1691feb417df7189a158b04e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 17 Aug 2024 11:22:53 -0700 Subject: [PATCH 13/17] Move padding back in front of windowing --- timm/models/hieradet_sam2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 4d57f5c3..4380c5e2 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -189,9 +189,9 @@ class MultiScaleBlock(nn.Module): H, W = x.shape[1:3] Hp, Wp = H, W # keep torchscript happy if self.is_windowed: - x = window_partition(x, window_size) Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size) x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + x = window_partition(x, window_size) # Window Attention + Q Pooling (if stage change) x = self.attn(x) From dc94cca0e5fc89c9562137b4e59be6aded6db4fb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 21 Aug 2024 10:06:27 -0700 Subject: [PATCH 14/17] Remaining Hiera sbb weights uploaded --- timm/models/hiera.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 78d32752..34d6670f 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -881,11 +881,24 @@ default_cfgs = generate_default_cfgs({ num_classes=0, ), + "hiera_small_abswin_256.sbb2_e200_in12k_ft_in1k": _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95, + ), + "hiera_small_abswin_256.sbb2_pd_e200_in12k_ft_in1k": _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95, + ), "hiera_small_abswin_256.sbb2_e200_in12k": _cfg( hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95, ), + "hiera_small_abswin_256.sbb2_pd_e200_in12k": _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95, + ), "hiera_base_abswin_256.untrained": _cfg( # hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95, From 9fcbf39cdcd90ee25a4abd9802249e1542cbc3c1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 21 Aug 2024 10:09:38 -0700 Subject: [PATCH 15/17] Add remaining sbb vit betwixt/mediumd fine-tunes --- timm/models/vision_transformer.py | 34 +++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index afb5e002..26e3de5c 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1964,6 +1964,9 @@ default_cfgs = { hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), @@ -1975,9 +1978,15 @@ default_cfgs = { hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), @@ -1992,6 +2001,9 @@ default_cfgs = { hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_reg4_gap_256.untrained': _cfg( input_size=(3, 256, 256)), @@ -3118,6 +3130,17 @@ def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visi return model +@register_model +def vit_mediumd_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_mediumd_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( @@ -3140,6 +3163,17 @@ def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visi return model +@register_model +def vit_betwixt_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_betwixt_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( From 47e6958263a870c362dc88d1a29868e0c1ce0f57 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 21 Aug 2024 11:05:54 -0700 Subject: [PATCH 16/17] Add hierdet_small (non sam) model def --- timm/models/hieradet_sam2.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 4380c5e2..6b69d583 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -500,7 +500,7 @@ class HieraDet(nn.Module): take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_head: - self.head.reset(0, reset_other=True) + self.head.reset(0, reset_other=prune_norm) return take_indices def forward_features(self, x: torch.Tensor) -> torch.Tensor: @@ -556,6 +556,10 @@ default_cfgs = generate_default_cfgs({ min_input_size=(3, 256, 256), input_size=(3, 1024, 1024), pool_size=(32, 32), ), + "hieradet_small.untrained": _cfg( + num_classes=1000, + input_size=(3, 256, 256), pool_size=(8, 8), + ), }) @@ -604,12 +608,6 @@ def sam2_hiera_small(pretrained=False, **kwargs): return _create_hiera_det('sam2_hiera_small', pretrained=pretrained, **dict(model_args, **kwargs)) -# @register_model -# def sam2_hiera_base(pretrained=False, **kwargs): -# model_args = dict() -# return _create_hiera_det('sam2_hiera_base', pretrained=pretrained, **dict(model_args, **kwargs)) - - @register_model def sam2_hiera_base_plus(pretrained=False, **kwargs): model_args = dict(embed_dim=112, num_heads=2, global_pos_size=(14, 14)) @@ -626,3 +624,15 @@ def sam2_hiera_large(pretrained=False, **kwargs): window_spec=(8, 4, 16, 8), ) return _create_hiera_det('sam2_hiera_large', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def hieradet_small(pretrained=False, **kwargs): + model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8)) + return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def hieradet_base(pretrained=False, **kwargs): +# model_args = dict(window_spec=(8, 4, 16, 8)) +# return _create_hiera_det('hieradet_base', pretrained=pretrained, **dict(model_args, **kwargs)) From 17923a66bb27f902c68cc5108d7de0258edee4c4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 21 Aug 2024 11:23:39 -0700 Subject: [PATCH 17/17] Add layer scale to hieradet --- timm/models/hieradet_sam2.py | 57 +++++++++++++++++------------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 6b69d583..d9585a52 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -1,4 +1,5 @@ import math +from copy import deepcopy from functools import partial from typing import Callable, Dict, List, Optional, Tuple, Union @@ -8,7 +9,7 @@ import torch.nn.functional as F from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \ +from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \ get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn from ._builder import build_model_with_cfg @@ -121,11 +122,12 @@ class MultiScaleBlock(nn.Module): dim_out: int, num_heads: int, mlp_ratio: float = 4.0, - drop_path: float = 0.0, q_stride: Optional[Tuple[int, int]] = None, norm_layer: Union[nn.Module, str] = "LayerNorm", act_layer: Union[nn.Module, str] = "GELU", window_size: int = 0, + init_values: Optional[float] = None, + drop_path: float = 0.0, ): super().__init__() norm_layer = get_norm_layer(norm_layer) @@ -135,30 +137,6 @@ class MultiScaleBlock(nn.Module): self.dim = dim self.dim_out = dim_out self.q_stride = q_stride - if self.q_stride: - q_pool = nn.MaxPool2d( - kernel_size=q_stride, - stride=q_stride, - ceil_mode=False, - ) - else: - q_pool = None - - self.norm1 = norm_layer(dim) - self.attn = MultiScaleAttention( - dim, - dim_out, - num_heads=num_heads, - q_pool=q_pool, - ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.norm2 = norm_layer(dim_out) - self.mlp = Mlp( - dim_out, - int(dim_out * mlp_ratio), - act_layer=act_layer, - ) if dim != dim_out: self.proj = nn.Linear(dim, dim_out) @@ -173,6 +151,25 @@ class MultiScaleBlock(nn.Module): ceil_mode=False, ) + self.norm1 = norm_layer(dim) + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=deepcopy(self.pool), + ) + self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp( + dim_out, + int(dim_out * mlp_ratio), + act_layer=act_layer, + ) + self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x # B, H, W, C x = self.norm1(x) @@ -206,9 +203,8 @@ class MultiScaleBlock(nn.Module): x = window_unpartition(x, window_size, (Hp, Wp)) x = x[:, :H, :W, :].contiguous() # unpad - x = shortcut + self.drop_path(x) - - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = shortcut + self.drop_path1(self.ls1(x)) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x @@ -280,6 +276,7 @@ class HieraDet(nn.Module): 16, 20, ), + init_values: Optional[float] = None, weight_init: str = '', fix_init: bool = True, head_init_scale: float = 0.001, @@ -628,7 +625,7 @@ def sam2_hiera_large(pretrained=False, **kwargs): @register_model def hieradet_small(pretrained=False, **kwargs): - model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8)) + model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5) return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))