From a50e53d41fa937c0baaa6e9318fde252e3e8b389 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 15 Aug 2024 17:46:36 -0700
Subject: [PATCH] Rename global pos embed for Hiera abswin, factor out commonly
 used vit weight init fns to layers. Add a channels-last ver of normmlp head.

---
 timm/layers/__init__.py    |  5 ++-
 timm/layers/classifier.py  | 76 +++++++++++++++++++++++++++++++++++++-
 timm/layers/weight_init.py | 44 +++++++++++++++++++++-
 timm/models/hiera.py       | 72 ++++++++++--------------------------
 4 files changed, 141 insertions(+), 56 deletions(-)

diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index 61115589..49ffa0ce 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -5,7 +5,7 @@ from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttention
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d, create_aa
-from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
+from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
     set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
@@ -57,4 +57,5 @@ from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2d
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType
-from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
+from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
+    init_weight_jax, init_weight_vit
diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py
index 2441c050..1cbab683 100644
--- a/timm/layers/classifier.py
+++ b/timm/layers/classifier.py
@@ -134,7 +134,8 @@ class ClassifierHead(nn.Module):
 
 
 class NormMlpClassifierHead(nn.Module):
-
+    """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
+    """
     def __init__(
             self,
             in_features: int,
@@ -204,3 +205,76 @@ class NormMlpClassifierHead(nn.Module):
             return x
         x = self.fc(x)
         return x
+
+
+class ClNormMlpClassifierHead(nn.Module):
+    """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
+    """
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            hidden_size: Optional[int] = None,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            norm_layer: Union[str, Callable] = 'layernorm',
+            act_layer: Union[str, Callable] = 'gelu',
+            input_fmt: str = 'NHWC',
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+            norm_layer: Normalization layer type.
+            act_layer: MLP activation layer type (only used if hidden_size is not None).
+        """
+        super().__init__()
+        self.in_features = in_features
+        self.hidden_size = hidden_size
+        self.num_features = in_features
+        assert pool_type in ('', 'avg', 'max', 'avgmax')
+        self.pool_type = pool_type
+        assert input_fmt in ('NHWC', 'NLC')
+        self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.pre_logits = nn.Identity()
+        self.drop = nn.Dropout(drop_rate)
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def _global_pool(self, x):
+        if self.pool_type:
+            if self.pool_type == 'avg':
+                x = x.mean(dim=self.pool_dim)
+            elif self.pool_type == 'max':
+                x = x.amax(dim=self.pool_dim)
+            elif self.pool_type == 'avgmax':
+                x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
+        return x
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self._global_pool(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
+        return x
diff --git a/timm/layers/weight_init.py b/timm/layers/weight_init.py
index 943e4f4c..d1127ecb 100644
--- a/timm/layers/weight_init.py
+++ b/timm/layers/weight_init.py
@@ -1,7 +1,7 @@
 import torch
 import math
 import warnings
-
+from torch import nn
 from torch.nn.init import _calculate_fan_in_and_fan_out
 
 
@@ -123,3 +123,45 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
 
 def lecun_normal_(tensor):
     variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
+
+
+def init_weight_vit(
+        module: nn.Module,
+        name: str,
+        init_bias: float = 0.02,
+        head_bias: float = 0.,
+        classifier_name: str = 'head'
+):
+    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        if name.startswith(classifier_name):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                nn.init.constant_(module.bias, init_bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
+def init_weight_jax(
+        module: nn.Module,
+        name: str,
+        head_bias: float = 0.,
+        classifier_name: str = 'head',
+):
+    if isinstance(module, nn.Linear):
+        if name.startswith(classifier_name):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index 808053e9..ec5d8b7b 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -32,7 +32,8 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple
+from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \
+    init_weight_vit, init_weight_jax
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
@@ -391,7 +392,7 @@ class NormClassifierHead(nn.Module):
         self.pool_type = pool_type
         self.norm = norm_layer(in_features)
         self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
-        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
         if pool_type is not None:
@@ -400,7 +401,7 @@ class NormClassifierHead(nn.Module):
         if other:
             # reset other non-fc layers
             self.norm = nn.Identity()
-        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
         if self.pool_type == 'avg':
@@ -486,7 +487,7 @@ class Hiera(nn.Module):
             head_init_scale: float = 0.001,
             sep_pos_embed: bool = False,
             abs_win_pos_embed: bool = False,
-            abs_pos_size: Tuple[int, int] = (14, 14),
+            global_pos_size: Tuple[int, int] = (14, 14),
     ):
         super().__init__()
         self.num_classes = num_classes
@@ -513,11 +514,9 @@ class Hiera(nn.Module):
             patch_kernel,
             patch_stride,
             patch_padding,
-            #reshape=False,  # leave spatial / temporal dims in output
         )
 
         self.pos_embed: Optional[nn.Parameter] = None
-        self.pos_embed_abs: Optional[nn.Parameter] = None
         self.pos_embed_win: Optional[nn.Parameter] = None
         self.pos_embed_spatial: Optional[nn.Parameter] = None
         self.pos_embed_temporal: Optional[nn.Parameter] = None
@@ -531,7 +530,7 @@ class Hiera(nn.Module):
         else:
             if abs_win_pos_embed:
                 # absolute win, params NCHW to make tile & interpolate more natural before add & reshape
-                self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))
+                self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size))
                 self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
             else:
                 self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
@@ -555,7 +554,7 @@ class Hiera(nn.Module):
         # Transformer blocks
         cur_stage = 0
         depth = sum(stages)
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
@@ -607,12 +606,12 @@ class Hiera(nn.Module):
         else:
             if self.pos_embed is not None:
                 nn.init.trunc_normal_(self.pos_embed, std=0.02)
-            elif self.pos_embed_abs is not None:
-                nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
+
             if self.pos_embed_win is not None:
                 nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
         if weight_init != 'skip':
-            init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit
+            init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
+            init_fn = partial(init_fn, classifier_name='head.fc')
             named_apply(init_fn, self)
         if fix_init:
             self.fix_init_weight()
@@ -681,20 +680,20 @@ class Hiera(nn.Module):
         return mask.bool()
 
     def _pos_embed(self, x) -> torch.Tensor:
-        if self.pos_embed is not None:
-            pos_embed = self.pos_embed
-        elif self.pos_embed_abs is not None:
+        if self.pos_embed_win is not None:
             # absolute win position embedding, from
             # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
             pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
-            pos_embed_abs = F.interpolate(
-                self.pos_embed_abs,
+            pos_embed = F.interpolate(
+                self.pos_embed,
                 size=pos_embed_win.shape[-2:],
                 mode='bicubic',
                 antialias=True,
             )
-            pos_embed = pos_embed_abs + pos_embed_win
+            pos_embed = pos_embed + pos_embed_win
             pos_embed = pos_embed.flatten(2).transpose(1, 2)
+        elif self.pos_embed is not None:
+            pos_embed = self.pos_embed
         else:
             pos_embed = (
                 self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
@@ -838,37 +837,6 @@ class Hiera(nn.Module):
         return x
 
 
-def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
-    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
-        if name.startswith('head.fc'):
-            nn.init.zeros_(module.weight)
-            nn.init.constant_(module.bias, head_bias)
-        else:
-            nn.init.trunc_normal_(module.weight, std=0.02)
-            if isinstance(module, nn.Linear) and module.bias is not None:
-                nn.init.constant_(module.bias, init_bias)
-    elif hasattr(module, 'init_weights'):
-        module.init_weights()
-
-
-def _init_weight_jax(module, name, head_bias=0.):
-    if isinstance(module, nn.Linear):
-        if name.startswith('head.fc'):
-            nn.init.zeros_(module.weight)
-            nn.init.constant_(module.bias, head_bias)
-        else:
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
-    elif isinstance(module, nn.Conv2d):
-        from timm.models.layers import lecun_normal_
-        lecun_normal_(module.weight)
-        if module.bias is not None:
-            nn.init.zeros_(module.bias)
-    elif hasattr(module, 'init_weights'):
-        module.init_weights()
-
-
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -972,6 +940,8 @@ def checkpoint_filter_fn(state_dict, model=None):
             k = k.replace('encoder_norm.', 'head.norm.')
         elif k.startswith('norm.'):
             k = k.replace('norm.', 'head.norm.')
+        if k == 'pos_embed_abs':
+            k = 'pos_embed'
         output[k] = v
 
     return output
@@ -1028,7 +998,7 @@ def hiera_huge_224(pretrained=False, **kwargs):
 @register_model
 def hiera_small_abswin_256(pretrained=False, **kwargs):
     model_args = dict(
-        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, global_pos_size=(16, 16),
         init_values=1e-5, weight_init='jax', use_expand_proj=False,
     )
     return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1037,7 +1007,5 @@ def hiera_small_abswin_256(pretrained=False, **kwargs):
 @register_model
 def hiera_base_abswin_256(pretrained=False, **kwargs):
     model_args = dict(
-        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16),
-        init_values=1e-5, weight_init='jax',
-    )
+        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, init_values=1e-5, weight_init='jax')
     return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
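
Usage sketch for the new channels-last ClNormMlpClassifierHead added in timm/layers/classifier.py. The feature width, class count, and tensor shapes below are illustrative assumptions, not values taken from the patch:

    import torch
    from timm.layers import ClNormMlpClassifierHead

    # 4-D NHWC features, e.g. the unflattened output of a channels-last trunk.
    head = ClNormMlpClassifierHead(
        in_features=768,
        num_classes=1000,
        hidden_size=None,   # set to e.g. 768 to enable the pre-logits fc + act
        pool_type='avg',
        input_fmt='NHWC',
    )
    x = torch.randn(2, 7, 7, 768)        # N, H, W, C
    logits = head(x)                     # -> (2, 1000)
    feats = head(x, pre_logits=True)     # pooled + normed (pre-fc) features

    # 3-D NLC token sequences pool over the length dim instead.
    seq_head = ClNormMlpClassifierHead(768, 1000, pool_type='avg', input_fmt='NLC')
    logits = seq_head(torch.randn(2, 49, 768))   # -> (2, 1000)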
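The factored-out init fns take a (module, name) pair, so they can be driven by any named-module traversal. A minimal sketch, assuming a toy two-layer model whose classifier is simply named 'head'; the explicit named_modules() loop stands in for timm's named_apply, whereas inside Hiera.init_weights the patch binds classifier_name='head.fc' via functools.partial:

    from functools import partial

    import torch.nn as nn
    from timm.layers import init_weight_jax

    # Toy module tree standing in for a real model.
    model = nn.Sequential()
    model.add_module('mlp', nn.Linear(32, 64))
    model.add_module('head', nn.Linear(64, 10))

    # Bind the classifier prefix, mirroring the pattern now used in hiera.py.
    init_fn = partial(init_weight_jax, classifier_name='head')

    # Apply to every submodule with its qualified name, as named_apply() would.
    for name, module in model.named_modules():
        init_fn(module=module, name=name)
    # 'head' gets zeroed weights and a constant bias; other Linear layers get xavier init.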