Update Hiera model for abswin, more stable weight init, layer-scale. ImageNet-12k weights for hiera_small_abswin, and two of the sbb vits with improved reg4 init.

Ross Wightman 2024-08-14 12:22:40 -07:00
parent ac3470188b
commit fee91fdd41
2 changed files with 91 additions and 17 deletions
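
A minimal usage sketch of what the changes below enable through the public timm API (the model names and tags come straight from the cfgs in this diff; pretrained=True assumes the referenced timm/ Hub artifacts are published):

import timm

# New ImageNet-12k (11821-class) weights added in this commit:
hiera = timm.create_model('hiera_small_abswin_256.sbb2_ep200_in12k', pretrained=True)
vit = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k', pretrained=True)

# The new Hiera constructor args pass through create_model as kwargs; the values
# shown are the registered defaults for hiera_base_abswin_256 after this commit,
# repeated here only to illustrate the knobs:
model = timm.create_model('hiera_base_abswin_256', init_values=1e-5, weight_init='jax')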

timm/models/hiera.py

@@ -40,6 +40,7 @@ from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
+from ._manipulate import named_apply
 
 __all__ = ['Hiera']
@@ -309,6 +310,21 @@ class MaskUnitAttention(nn.Module):
         return x
 
 
+class LayerScale(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class HieraBlock(nn.Module):
     def __init__(
             self,
@@ -317,6 +333,7 @@ class HieraBlock(nn.Module):
             heads: int,
             mlp_ratio: float = 4.0,
             drop_path: float = 0.0,
+            init_values: Optional[float] = None,
             norm_layer: nn.Module = nn.LayerNorm,
             act_layer: nn.Module = nn.GELU,
             q_stride: int = 1,
@@ -348,13 +365,14 @@ class HieraBlock(nn.Module):
             window_size,
             use_mask_unit_attn
         )
+        self.ls1 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
 
         self.norm2 = norm_layer(dim_out)
         self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
+        self.ls2 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
         x_norm = self.norm1(x)
@@ -369,10 +387,10 @@ class HieraBlock(nn.Module):
                 ],
                 dim=-1,
             )
-        x = x + self.drop_path1(self.attn(x_norm))
+        x = x + self.drop_path1(self.ls1(self.attn(x_norm)))
 
         # MLP
-        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
@@ -470,6 +488,7 @@ class Hiera(nn.Module):
             mask_unit_size: Tuple[int, ...] = (8, 8),  # must divide q_stride ** (#stages-1)
             # mask_unit_attn: which stages use mask unit attention?
             mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
+            use_expand_proj: bool = True,
             dim_mul: float = 2.0,
             head_mul: float = 2.0,
             patch_kernel: Tuple[int, ...] = (7, 7),
@@ -477,6 +496,9 @@ class Hiera(nn.Module):
             patch_padding: Tuple[int, ...] = (3, 3),
             mlp_ratio: float = 4.0,
             drop_path_rate: float = 0.0,
+            init_values: Optional[float] = None,
+            fix_init: bool = True,
+            weight_init: str = '',
             norm_layer: Union[str, nn.Module] = "LayerNorm",
             drop_rate: float = 0.0,
             patch_drop_rate: float = 0.0,
@@ -575,9 +597,11 @@ class Hiera(nn.Module):
                 heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 drop_path=dpr[i],
+                init_values=init_values,
                 norm_layer=norm_layer,
                 q_stride=(flat_q_stride if i in q_pool_blocks else 1),
                 window_size=flat_mu_size,
+                use_expand_proj=use_expand_proj,
                 use_mask_unit_attn=use_mask_unit_attn,
             )
             embed_dim = dim_out
@@ -605,19 +629,25 @@ class Hiera(nn.Module):
         elif self.pos_embed_abs is not None:
             nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
             nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
-        self.apply(partial(self._init_weights))
+
+        if weight_init != 'skip':
+            if weight_init == 'jax':
+                named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self)
+            else:
+                named_apply(_init_weight_vit, self)
+        if fix_init:
+            self.fix_init_weight()
         if isinstance(self.head.fc, nn.Linear):
             self.head.fc.weight.data.mul_(head_init_scale)
             self.head.fc.bias.data.mul_(head_init_scale)
 
-    def _init_weights(self, m, init_bias=0.02):
-        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
-            nn.init.trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, init_bias)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, init_bias)
-            nn.init.constant_(m.weight, 1.0)
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
 
     @torch.jit.ignore
     def no_weight_decay(self):
@@ -829,6 +859,35 @@ class Hiera(nn.Module):
         return x
 
 
+def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
+    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        if name.startswith('head.fc'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                nn.init.constant_(module.bias, init_bias)
+
+
+def _init_weight_jax(module, name, head_bias=0.):
+    if isinstance(module, nn.Linear):
+        if name.startswith('head'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        from timm.models.layers import lecun_normal_
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -901,8 +960,9 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),
-    "hiera_small_abswin_256.untrained": _cfg(
-        #hf_hub_id='timm/',
+    "hiera_small_abswin_256.sbb2_ep200_in12k": _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
     "hiera_base_abswin_256.untrained": _cfg(
@@ -985,11 +1045,17 @@ def hiera_huge_224(pretrained=False, **kwargs):
 
 @register_model
 def hiera_small_abswin_256(pretrained=False, **kwargs):
-    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16))
+    model_args = dict(
+        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        init_values=1e-5, weight_init='jax', use_expand_proj=False,
+    )
     return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def hiera_base_abswin_256(pretrained=False, **kwargs):
-    model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16))
-    return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        init_values=1e-5, weight_init='jax',
+    )
+    return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
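
The two stabilization pieces added above are standard techniques: LayerScale (per-channel learned scaling of each residual branch, initialized tiny, as in CaiT) and a depth-scaled rescale of residual output weights in fix_init_weight(). A rough sketch of the reasoning, assuming a timm build that includes this commit for the LayerScale import:

import math
import torch
from timm.models.hiera import LayerScale

# fix_init_weight(): each block adds two residual branches (attn, mlp), so
# variance on the residual stream grows roughly linearly with depth. Dividing
# block i's output projections (attn.proj, mlp.fc2) by sqrt(2 * i) keeps each
# block's contribution roughly depth-independent.
depth = 16  # hiera_small has 16 blocks across stages (1, 2, 11, 2)
scales = [1.0 / math.sqrt(2.0 * i) for i in range(1, depth + 1)]
print(round(scales[0], 3), round(scales[-1], 3))  # 0.707 for block 1, 0.177 for block 16

# LayerScale starts each residual branch near zero (gamma = init_values per
# channel), letting the effective depth grow gradually during training:
ls = LayerScale(dim=8, init_values=1e-5)
x = torch.randn(2, 4, 8)
assert torch.allclose(ls(x), x * ls.gamma)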

timm/models/vision_transformer.py

@@ -1967,6 +1967,10 @@ default_cfgs = {
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
@@ -1980,6 +1984,10 @@ default_cfgs = {
     'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,