Rename global pos embed for Hiera abswin, factor commonly used ViT weight init fns out to layers, and add a channels-last version of the NormMlp classifier head.

Ross Wightman 2024-08-15 17:46:36 -07:00
parent 2f3fed43b8
commit a50e53d41f
4 changed files with 141 additions and 56 deletions

timm/layers/__init__.py

@ -5,7 +5,7 @@ from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttention
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d, create_aa
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
@ -57,4 +57,5 @@ from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2d
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .typing import LayerType, PadType
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
init_weight_jax, init_weight_vit

timm/layers/classifier.py

@ -134,7 +134,8 @@ class ClassifierHead(nn.Module):
class NormMlpClassifierHead(nn.Module):
""" A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
"""
def __init__(
self,
in_features: int,
@ -204,3 +205,76 @@ class NormMlpClassifierHead(nn.Module):
return x
x = self.fc(x)
return x
class ClNormMlpClassifierHead(nn.Module):
""" A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
"""
def __init__(
self,
in_features: int,
num_classes: int,
hidden_size: Optional[int] = None,
pool_type: str = 'avg',
drop_rate: float = 0.,
norm_layer: Union[str, Callable] = 'layernorm',
act_layer: Union[str, Callable] = 'gelu',
input_fmt: str = 'NHWC',
):
"""
Args:
in_features: The number of input features.
num_classes: The number of classes for the final classifier layer (output).
hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
pool_type: Global pooling type, pooling disabled if empty string ('').
drop_rate: Pre-classifier dropout rate.
norm_layer: Normalization layer type.
act_layer: MLP activation layer type (only used if hidden_size is not None).
input_fmt: Input tensor layout, 'NHWC' for spatial feature maps or 'NLC' for token sequences.
"""
super().__init__()
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
assert pool_type in ('', 'avg', 'max', 'avgmax')
self.pool_type = pool_type
assert input_fmt in ('NHWC', 'NLC')
self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
self.norm = norm_layer(in_features)
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(drop_rate)
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes: int, pool_type: Optional[str] = None):
if pool_type is not None:
self.pool_type = pool_type
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def _global_pool(self, x):
if self.pool_type:
if self.pool_type == 'avg':
x = x.mean(dim=self.pool_dim)
elif self.pool_type == 'max':
x = x.amax(dim=self.pool_dim)
elif self.pool_type == 'avgmax':
x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
return x
def forward(self, x, pre_logits: bool = False):
x = self._global_pool(x)
x = self.norm(x)
x = self.pre_logits(x)
x = self.drop(x)
if pre_logits:
return x
x = self.fc(x)
return x
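
A minimal usage sketch of the new channels-last head; the tensor shapes and hidden_size below are illustrative, not taken from this diff:

import torch
from timm.layers import ClNormMlpClassifierHead

# Channels-last feature map: N, H, W, C (shapes are illustrative).
x = torch.randn(2, 7, 7, 768)

head = ClNormMlpClassifierHead(
    in_features=768,
    num_classes=1000,
    hidden_size=512,      # optional pre-logits fc + act; None skips it
    pool_type='avg',
    input_fmt='NHWC',     # 'NLC' pools over dim 1 for token sequences
)

logits = head(x)                   # (2, 1000)
feats = head(x, pre_logits=True)   # (2, 512) pooled, normed, pre-logits features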

timm/layers/weight_init.py

@ -1,7 +1,7 @@
import torch
import math
import warnings
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
@ -123,3 +123,45 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
def init_weight_vit(
module: nn.Module,
name: str,
init_bias: float = 0.02,
head_bias: float = 0.,
classifier_name: str = 'head'
):
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
if name.startswith(classifier_name):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.trunc_normal_(module.weight, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.constant_(module.bias, init_bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def init_weight_jax(
module: nn.Module,
name: str,
head_bias: float = 0.,
classifier_name: str = 'head',
):
if isinstance(module, nn.Linear):
if name.startswith(classifier_name):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
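
These helpers take (module, name) pairs, so they can be applied with timm's named_apply or a plain named_modules loop; a minimal sketch with a hypothetical two-layer model:

from collections import OrderedDict
from functools import partial
import torch.nn as nn
from timm.layers import init_weight_vit  # init_weight_jax follows the same call pattern

# Hypothetical module tree whose classifier submodule is named 'head'.
model = nn.Sequential(OrderedDict([
    ('stem', nn.Linear(128, 256)),
    ('head', nn.Linear(256, 1000)),
]))

# classifier_name must prefix-match the classifier's module name so its weight
# is zeroed and its bias set to head_bias; other layers get trunc_normal init.
init_fn = partial(init_weight_vit, classifier_name='head')
for name, module in model.named_modules():
    init_fn(module, name)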

timm/models/hiera.py

@ -32,7 +32,8 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple
from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \
init_weight_vit, init_weight_jax
from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg
@ -391,7 +392,7 @@ class NormClassifierHead(nn.Module):
self.pool_type = pool_type
self.norm = norm_layer(in_features)
self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
if pool_type is not None:
@ -400,7 +401,7 @@ class NormClassifierHead(nn.Module):
if other:
# reset other non-fc layers
self.norm = nn.Identity()
self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
if self.pool_type == 'avg':
@ -486,7 +487,7 @@ class Hiera(nn.Module):
head_init_scale: float = 0.001,
sep_pos_embed: bool = False,
abs_win_pos_embed: bool = False,
abs_pos_size: Tuple[int, int] = (14, 14),
global_pos_size: Tuple[int, int] = (14, 14),
):
super().__init__()
self.num_classes = num_classes
@ -513,11 +514,9 @@ class Hiera(nn.Module):
patch_kernel,
patch_stride,
patch_padding,
#reshape=False, # leave spatial / temporal dims in output
)
self.pos_embed: Optional[nn.Parameter] = None
self.pos_embed_abs: Optional[nn.Parameter] = None
self.pos_embed_win: Optional[nn.Parameter] = None
self.pos_embed_spatial: Optional[nn.Parameter] = None
self.pos_embed_temporal: Optional[nn.Parameter] = None
@ -531,7 +530,7 @@ class Hiera(nn.Module):
else:
if abs_win_pos_embed:
# absolute win, params NCHW to make tile & interpolate more natural before add & reshape
self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size))
self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
else:
self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
@ -555,7 +554,7 @@ class Hiera(nn.Module):
# Transformer blocks
cur_stage = 0
depth = sum(stages)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
self.blocks = nn.ModuleList()
self.feature_info = []
for i in range(depth):
@ -607,12 +606,12 @@ class Hiera(nn.Module):
else:
if self.pos_embed is not None:
nn.init.trunc_normal_(self.pos_embed, std=0.02)
elif self.pos_embed_abs is not None:
nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
if self.pos_embed_win is not None:
nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
if weight_init != 'skip':
init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit
init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
init_fn = partial(init_fn, classifier_name='head.fc')
named_apply(init_fn, self)
if fix_init:
self.fix_init_weight()
@ -681,20 +680,20 @@ class Hiera(nn.Module):
return mask.bool()
def _pos_embed(self, x) -> torch.Tensor:
if self.pos_embed is not None:
pos_embed = self.pos_embed
elif self.pos_embed_abs is not None:
if self.pos_embed_win is not None:
# absolute win position embedding, from
# Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
pos_embed_abs = F.interpolate(
self.pos_embed_abs,
pos_embed = F.interpolate(
self.pos_embed,
size=pos_embed_win.shape[-2:],
mode='bicubic',
antialias=True,
)
pos_embed = pos_embed_abs + pos_embed_win
pos_embed = pos_embed + pos_embed_win
pos_embed = pos_embed.flatten(2).transpose(1, 2)
elif self.pos_embed is not None:
pos_embed = self.pos_embed
else:
pos_embed = (
self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
@ -838,37 +837,6 @@ class Hiera(nn.Module):
return x
def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
if name.startswith('head.fc'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.trunc_normal_(module.weight, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.constant_(module.bias, init_bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def _init_weight_jax(module, name, head_bias=0.):
if isinstance(module, nn.Linear):
if name.startswith('head.fc'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
from timm.models.layers import lecun_normal_
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def _cfg(url='', **kwargs):
return {
'url': url,
@ -972,6 +940,8 @@ def checkpoint_filter_fn(state_dict, model=None):
k = k.replace('encoder_norm.', 'head.norm.')
elif k.startswith('norm.'):
k = k.replace('norm.', 'head.norm.')
if k == 'pos_embed_abs':
k = 'pos_embed'
output[k] = v
return output
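
For the renamed parameter: with abs_win_pos_embed the model keeps a small global table (now pos_embed, formerly pos_embed_abs; old checkpoint keys are remapped above) plus a per-window table (pos_embed_win). A standalone sketch of how Hiera._pos_embed combines them, with assumed shapes:

import torch
import torch.nn.functional as F

embed_dim = 96
mask_spatial_shape = (8, 8)   # assumed: mask units per spatial dim
mask_unit_size = (4, 4)       # assumed: tokens per mask unit

pos_embed = torch.zeros(1, embed_dim, 14, 14)               # global table, kept NCHW
pos_embed_win = torch.zeros(1, embed_dim, *mask_unit_size)  # per-window table

# Tile the window table across every mask unit, resize the global table to the
# same token grid via bicubic interpolation, then add and flatten to (N, L, C).
win = pos_embed_win.tile((1, 1) + mask_spatial_shape)        # (1, C, 32, 32)
glb = F.interpolate(pos_embed, size=win.shape[-2:], mode='bicubic', antialias=True)
pos = (glb + win).flatten(2).transpose(1, 2)                 # (1, 1024, C)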
@ -1028,7 +998,7 @@ def hiera_huge_224(pretrained=False, **kwargs):
@register_model
def hiera_small_abswin_256(pretrained=False, **kwargs):
model_args = dict(
embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16),
embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, global_pos_size=(16, 16),
init_values=1e-5, weight_init='jax', use_expand_proj=False,
)
return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1037,7 +1007,5 @@ def hiera_small_abswin_256(pretrained=False, **kwargs):
@register_model
def hiera_base_abswin_256(pretrained=False, **kwargs):
model_args = dict(
embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16),
init_values=1e-5, weight_init='jax',
)
embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, init_values=1e-5, weight_init='jax')
return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
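
Code that overrode the old constructor kwarg needs the new name; a minimal sketch, assuming no pretrained weights:

import timm

# abs_pos_size was renamed to global_pos_size, so per-model overrides passed
# through create_model kwargs use the new name.
model = timm.create_model(
    'hiera_base_abswin_256',
    pretrained=False,
    global_pos_size=(16, 16),   # previously abs_pos_size=(16, 16)
)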