Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Rename global pos embed for Hiera abswin, factor out commonly used vit weight init fns to layers. Add a channels-last ver of normmlp head.
This commit is contained in:
parent 2f3fed43b8
commit a50e53d41f
timm/layers/__init__.py
@@ -5,7 +5,7 @@ from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttention
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d, create_aa
-from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
+from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
     set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
@@ -57,4 +57,5 @@ from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2d
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType
-from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
+from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
+    init_weight_jax, init_weight_vit
timm/layers/classifier.py
@@ -134,7 +134,8 @@ class ClassifierHead(nn.Module):
 
 
 class NormMlpClassifierHead(nn.Module):
     """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
     """
     def __init__(
             self,
             in_features: int,
@@ -204,3 +205,76 @@ class NormMlpClassifierHead(nn.Module):
             return x
         x = self.fc(x)
         return x
+
+
+class ClNormMlpClassifierHead(nn.Module):
+    """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
+    """
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            hidden_size: Optional[int] = None,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            norm_layer: Union[str, Callable] = 'layernorm',
+            act_layer: Union[str, Callable] = 'gelu',
+            input_fmt: str = 'NHWC',
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+            norm_layer: Normalization layer type.
+            act_layer: MLP activation layer type (only used if hidden_size is not None).
+        """
+        super().__init__()
+        self.in_features = in_features
+        self.hidden_size = hidden_size
+        self.num_features = in_features
+        assert pool_type in ('', 'avg', 'max', 'avgmax')
+        self.pool_type = pool_type
+        assert input_fmt in ('NHWC', 'NLC')
+        self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.pre_logits = nn.Identity()
+        self.drop = nn.Dropout(drop_rate)
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def _global_pool(self, x):
+        if self.pool_type:
+            if self.pool_type == 'avg':
+                x = x.mean(dim=self.pool_dim)
+            elif self.pool_type == 'max':
+                x = x.amax(dim=self.pool_dim)
+            elif self.pool_type == 'avgmax':
+                x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
+        return x
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self._global_pool(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
+        return x
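For context on how the new channels-last head is meant to be used: it pools over the spatial (or sequence) dims, then applies norm, an optional pre-logits MLP, dropout, and the final linear. A minimal standalone sketch, with illustrative shapes and hyperparameters that are not part of the commit:

```python
import torch

from timm.layers import ClNormMlpClassifierHead

# Illustrative NHWC feature map: batch 2, 7x7 spatial grid, 768 channels (channels-last).
feats = torch.randn(2, 7, 7, 768)

head = ClNormMlpClassifierHead(
    in_features=768,
    num_classes=1000,
    hidden_size=512,   # enables the pre-logits fc + act stage
    pool_type='avg',   # mean over dims (1, 2) for NHWC input
    input_fmt='NHWC',
)

logits = head(feats)                   # -> (2, 1000)
embeds = head(feats, pre_logits=True)  # -> (2, 512): pooled, normed, pre-logits features
```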
timm/layers/weight_init.py
@@ -1,7 +1,7 @@
 import torch
 import math
 import warnings
 
+from torch import nn
 from torch.nn.init import _calculate_fan_in_and_fan_out
 
 
@@ -123,3 +123,45 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
 
 def lecun_normal_(tensor):
     variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
+
+
+def init_weight_vit(
+        module: nn.Module,
+        name: str,
+        init_bias: float = 0.02,
+        head_bias: float = 0.,
+        classifier_name: str = 'head'
+):
+    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        if name.startswith(classifier_name):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                nn.init.constant_(module.bias, init_bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
+def init_weight_jax(
+        module: nn.Module,
+        name: str,
+        head_bias: float = 0.,
+        classifier_name: str = 'head',
+):
+    if isinstance(module, nn.Linear):
+        if name.startswith(classifier_name):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
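The factored-out init functions are designed to be bound with functools.partial (as hiera.py now does with classifier_name='head.fc') and walked over a model's modules by dotted name. A small sketch on a toy module, assuming named_apply from timm.models._manipulate and a hypothetical classifier submodule named 'head':

```python
from functools import partial

import torch.nn as nn

from timm.layers import init_weight_jax, init_weight_vit
from timm.models._manipulate import named_apply  # assumption: helper location in current timm

# Toy model: a feature layer plus a classifier submodule named 'head'.
model = nn.Sequential()
model.add_module('stem', nn.Linear(32, 64))
model.add_module('head', nn.Linear(64, 10))

# Bind the classifier prefix, then apply module-by-module using dotted names;
# the 'head' linear gets zero weights + head_bias, everything else trunc-normal init.
named_apply(partial(init_weight_vit, classifier_name='head'), model)

# JAX-style variant: xavier_uniform for non-classifier linears, lecun_normal for convs.
named_apply(partial(init_weight_jax, classifier_name='head'), model)
```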
timm/models/hiera.py
@@ -32,7 +32,8 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple
+from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \
+    init_weight_vit, init_weight_jax
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
@@ -391,7 +392,7 @@ class NormClassifierHead(nn.Module):
         self.pool_type = pool_type
         self.norm = norm_layer(in_features)
         self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
-        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
         if pool_type is not None:
@@ -400,7 +401,7 @@ class NormClassifierHead(nn.Module):
         if other:
             # reset other non-fc layers
             self.norm = nn.Identity()
-        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
         if self.pool_type == 'avg':
@@ -486,7 +487,7 @@ class Hiera(nn.Module):
             head_init_scale: float = 0.001,
             sep_pos_embed: bool = False,
             abs_win_pos_embed: bool = False,
-            abs_pos_size: Tuple[int, int] = (14, 14),
+            global_pos_size: Tuple[int, int] = (14, 14),
     ):
         super().__init__()
         self.num_classes = num_classes
@@ -513,11 +514,9 @@ class Hiera(nn.Module):
             patch_kernel,
             patch_stride,
             patch_padding,
-            #reshape=False, # leave spatial / temporal dims in output
         )
 
         self.pos_embed: Optional[nn.Parameter] = None
-        self.pos_embed_abs: Optional[nn.Parameter] = None
         self.pos_embed_win: Optional[nn.Parameter] = None
         self.pos_embed_spatial: Optional[nn.Parameter] = None
         self.pos_embed_temporal: Optional[nn.Parameter] = None
@@ -531,7 +530,7 @@ class Hiera(nn.Module):
         else:
             if abs_win_pos_embed:
                 # absolute win, params NCHW to make tile & interpolate more natural before add & reshape
-                self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))
+                self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size))
                 self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
             else:
                 self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
@@ -555,7 +554,7 @@ class Hiera(nn.Module):
         # Transformer blocks
         cur_stage = 0
         depth = sum(stages)
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
@@ -607,12 +606,12 @@ class Hiera(nn.Module):
         else:
             if self.pos_embed is not None:
                 nn.init.trunc_normal_(self.pos_embed, std=0.02)
-            elif self.pos_embed_abs is not None:
-                nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
             if self.pos_embed_win is not None:
                 nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
 
         if weight_init != 'skip':
-            init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit
+            init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
+            init_fn = partial(init_fn, classifier_name='head.fc')
             named_apply(init_fn, self)
         if fix_init:
             self.fix_init_weight()
@@ -681,20 +680,20 @@ class Hiera(nn.Module):
         return mask.bool()
 
     def _pos_embed(self, x) -> torch.Tensor:
-        if self.pos_embed is not None:
-            pos_embed = self.pos_embed
-        elif self.pos_embed_abs is not None:
+        if self.pos_embed_win is not None:
             # absolute win position embedding, from
             # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
             pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
-            pos_embed_abs = F.interpolate(
-                self.pos_embed_abs,
+            pos_embed = F.interpolate(
+                self.pos_embed,
                 size=pos_embed_win.shape[-2:],
                 mode='bicubic',
                 antialias=True,
             )
-            pos_embed = pos_embed_abs + pos_embed_win
+            pos_embed = pos_embed + pos_embed_win
             pos_embed = pos_embed.flatten(2).transpose(1, 2)
+        elif self.pos_embed is not None:
+            pos_embed = self.pos_embed
         else:
             pos_embed = (
                 self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
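The renamed global table stays NCHW so the abswin scheme from the linked paper can tile the per-window embedding over the mask-unit grid and bicubically resize the global embedding to that size before adding and flattening, rather than interpolating across window boundaries. A standalone sketch of that computation with made-up sizes (names mirror the diff; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

embed_dim = 96
global_pos_size = (16, 16)     # learned global table, kept NCHW (1, C, H, W)
mask_unit_size = (8, 8)        # per-mask-unit (window) embedding
mask_spatial_shape = (7, 7)    # mask units across H and W at the current input size

pos_embed = torch.randn(1, embed_dim, *global_pos_size)
pos_embed_win = torch.randn(1, embed_dim, *mask_unit_size)

# Tile the window embedding over the mask-unit grid; tile() repeats the trailing dims -> (1, C, 56, 56).
pos_embed_win_tiled = pos_embed_win.tile(mask_spatial_shape)
# Resize the global table to the tiled size so interpolation never crosses window boundaries.
pos_embed_global = F.interpolate(
    pos_embed,
    size=pos_embed_win_tiled.shape[-2:],
    mode='bicubic',
    antialias=True,
)
# Add, then flatten NCHW -> (B, N, C) to match the token sequence.
pos = (pos_embed_global + pos_embed_win_tiled).flatten(2).transpose(1, 2)
print(pos.shape)  # torch.Size([1, 3136, 96])
```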
@@ -838,37 +837,6 @@ class Hiera(nn.Module):
         return x
 
 
-def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
-    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
-        if name.startswith('head.fc'):
-            nn.init.zeros_(module.weight)
-            nn.init.constant_(module.bias, head_bias)
-        else:
-            nn.init.trunc_normal_(module.weight, std=0.02)
-            if isinstance(module, nn.Linear) and module.bias is not None:
-                nn.init.constant_(module.bias, init_bias)
-    elif hasattr(module, 'init_weights'):
-        module.init_weights()
-
-
-def _init_weight_jax(module, name, head_bias=0.):
-    if isinstance(module, nn.Linear):
-        if name.startswith('head.fc'):
-            nn.init.zeros_(module.weight)
-            nn.init.constant_(module.bias, head_bias)
-        else:
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
-    elif isinstance(module, nn.Conv2d):
-        from timm.models.layers import lecun_normal_
-        lecun_normal_(module.weight)
-        if module.bias is not None:
-            nn.init.zeros_(module.bias)
-    elif hasattr(module, 'init_weights'):
-        module.init_weights()
-
-
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -972,6 +940,8 @@ def checkpoint_filter_fn(state_dict, model=None):
             k = k.replace('encoder_norm.', 'head.norm.')
         elif k.startswith('norm.'):
             k = k.replace('norm.', 'head.norm.')
+        if k == 'pos_embed_abs':
+            k = 'pos_embed'
         output[k] = v
     return output
 
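The extra checkpoint_filter_fn branch keeps pre-rename abswin checkpoints loadable by mapping the old pos_embed_abs key onto the renamed parameter. Roughly, the remap does this (hypothetical state_dict fragment):

```python
import torch

# Hypothetical pre-rename checkpoint fragment; checkpoint_filter_fn performs this remap
# before load_state_dict so old abswin weights land on the renamed 'pos_embed' parameter.
old_sd = {'pos_embed_abs': torch.zeros(1, 96, 16, 16), 'pos_embed_win': torch.zeros(1, 96, 8, 8)}
new_sd = {('pos_embed' if k == 'pos_embed_abs' else k): v for k, v in old_sd.items()}
assert 'pos_embed' in new_sd and 'pos_embed_abs' not in new_sd
```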
@@ -1028,7 +998,7 @@ def hiera_huge_224(pretrained=False, **kwargs):
 @register_model
 def hiera_small_abswin_256(pretrained=False, **kwargs):
     model_args = dict(
-        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, global_pos_size=(16, 16),
         init_values=1e-5, weight_init='jax', use_expand_proj=False,
     )
     return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1037,7 +1007,5 @@ def hiera_small_abswin_256(pretrained=False, **kwargs):
 @register_model
 def hiera_base_abswin_256(pretrained=False, **kwargs):
     model_args = dict(
-        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16),
-        init_values=1e-5, weight_init='jax',
-    )
+        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, init_values=1e-5, weight_init='jax')
     return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
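With the rename, the abswin model configs set the global table size via global_pos_size instead of abs_pos_size. A hedged sketch of building one of the registered variants through the usual timm factory, random-init only (pretrained availability and the num_classes override are assumptions here):

```python
import timm
import torch

# Random-init build of the abswin variant registered above; num_classes=10 is illustrative.
model = timm.create_model('hiera_small_abswin_256', pretrained=False, num_classes=10)

x = torch.randn(1, 3, 256, 256)
logits = model(x)
print(logits.shape)  # torch.Size([1, 10])
```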