mirror of https://github.com/huggingface/pytorch-image-models.git

Update Hiera model for abswin, more stable weight init, layer-scale. ImageNet-12k weights for hiera_small_abswin, and two of the sbb vits with improved reg4 init.

parent ac3470188b
commit fee91fdd41
@@ -40,6 +40,7 @@ from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
+from ._manipulate import named_apply


 __all__ = ['Hiera']
@@ -309,6 +310,21 @@ class MaskUnitAttention(nn.Module):
         return x


+class LayerScale(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class HieraBlock(nn.Module):
     def __init__(
             self,
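A quick standalone look at the LayerScale module added above (class body condensed from the hunk); with the default init_values=1e-5 the scaled branch contributes almost nothing at initialization, which is what helps deeper stacks train more stably. A minimal sketch, not part of the commit:

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # same logic as the class added in the hunk above, signature condensed to one line
    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

ls = LayerScale(dim=8)
branch = torch.randn(2, 4, 8)              # e.g. an attention-branch output (B, tokens, C)
print(ls(branch).abs().max().item())       # roughly 1e-5 of the input scale at init
out = torch.randn(2, 4, 8) + ls(branch)    # residual add: the branch starts almost silent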
@@ -317,6 +333,7 @@ class HieraBlock(nn.Module):
             heads: int,
             mlp_ratio: float = 4.0,
             drop_path: float = 0.0,
+            init_values: Optional[float] = None,
             norm_layer: nn.Module = nn.LayerNorm,
             act_layer: nn.Module = nn.GELU,
             q_stride: int = 1,
@@ -348,13 +365,14 @@ class HieraBlock(nn.Module):
             window_size,
             use_mask_unit_attn
         )
+        self.ls1 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

         self.norm2 = norm_layer(dim_out)
         self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
+        self.ls2 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
         x_norm = self.norm1(x)
@@ -369,10 +387,10 @@ class HieraBlock(nn.Module):
                 ],
                 dim=-1,
             )
-        x = x + self.drop_path1(self.attn(x_norm))
+        x = x + self.drop_path1(self.ls1(self.attn(x_norm)))

         # MLP
-        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x


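The forward hunk routes both residual branches through ls1/ls2, and with the default init_values=None those are nn.Identity(), so existing configs behave exactly as before. A hedged toy sketch of the wiring (simplified stand-ins, not the real HieraBlock):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Toy pre-norm block illustrating x = x + ls(branch(norm(x))); names are stand-ins.
    def __init__(self, dim: int, init_values=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)     # stand-in for MaskUnitAttention
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        if init_values is not None:
            # per-channel layer-scale parameters, initialized small
            self.gamma1 = nn.Parameter(init_values * torch.ones(dim))
            self.gamma2 = nn.Parameter(init_values * torch.ones(dim))
        else:
            self.gamma1 = self.gamma2 = None  # equivalent to nn.Identity(), i.e. old behaviour

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.attn(self.norm1(x))
        x = x + (a * self.gamma1 if self.gamma1 is not None else a)
        m = self.mlp(self.norm2(x))
        x = x + (m * self.gamma2 if self.gamma2 is not None else m)
        return x

x = torch.randn(2, 16, 32)
print(ToyBlock(32, init_values=1e-5)(x).shape)   # torch.Size([2, 16, 32])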
@@ -470,6 +488,7 @@ class Hiera(nn.Module):
             mask_unit_size: Tuple[int, ...] = (8, 8),  # must divide q_stride ** (#stages-1)
             # mask_unit_attn: which stages use mask unit attention?
             mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
+            use_expand_proj: bool = True,
             dim_mul: float = 2.0,
             head_mul: float = 2.0,
             patch_kernel: Tuple[int, ...] = (7, 7),
@@ -477,6 +496,9 @@ class Hiera(nn.Module):
             patch_padding: Tuple[int, ...] = (3, 3),
             mlp_ratio: float = 4.0,
             drop_path_rate: float = 0.0,
+            init_values: Optional[float] = None,
+            fix_init: bool = True,
+            weight_init: str = '',
             norm_layer: Union[str, nn.Module] = "LayerNorm",
             drop_rate: float = 0.0,
             patch_drop_rate: float = 0.0,
@@ -575,9 +597,11 @@ class Hiera(nn.Module):
                 heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 drop_path=dpr[i],
+                init_values=init_values,
                 norm_layer=norm_layer,
                 q_stride=(flat_q_stride if i in q_pool_blocks else 1),
                 window_size=flat_mu_size,
+                use_expand_proj=use_expand_proj,
                 use_mask_unit_attn=use_mask_unit_attn,
             )
             embed_dim = dim_out
@@ -605,19 +629,25 @@ class Hiera(nn.Module):
         elif self.pos_embed_abs is not None:
             nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
             nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
-        self.apply(partial(self._init_weights))
+
+        if weight_init != 'skip':
+            if weight_init == 'jax':
+                named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self)
+            else:
+                named_apply(_init_weight_vit, self)
+            if fix_init:
+                self.fix_init_weight()
         if isinstance(self.head.fc, nn.Linear):
             self.head.fc.weight.data.mul_(head_init_scale)
             self.head.fc.bias.data.mul_(head_init_scale)

-    def _init_weights(self, m, init_bias=0.02):
-        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
-            nn.init.trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, init_bias)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, init_bias)
-            nn.init.constant_(m.weight, 1.0)
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

     @torch.jit.ignore
     def no_weight_decay(self):
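The new fix_init_weight pass divides each block's attn.proj and mlp.fc2 weights by sqrt(2 * layer_id), a depth-dependent damping similar to BEiT-style residual rescaling. A small illustration (not part of the commit) of the factors it applies:

import math

# Hedged illustration of the scaling used by fix_init_weight above:
# weights are divided by sqrt(2 * layer_id), so deeper blocks start with a
# smaller residual contribution (layer_id is 1-based).
for layer_id in (1, 2, 4, 8, 16, 24):
    factor = 1.0 / math.sqrt(2.0 * layer_id)
    print(f"block {layer_id:2d}: weights scaled by {factor:.3f}")
# block 1: 0.707, block 2: 0.500, block 4: 0.354, block 8: 0.250, block 16: 0.177, block 24: 0.144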
@@ -829,6 +859,35 @@ class Hiera(nn.Module):
         return x


+def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
+    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        if name.startswith('head.fc'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                nn.init.constant_(module.bias, init_bias)
+
+
+def _init_weight_jax(module, name, head_bias=0.):
+    if isinstance(module, nn.Linear):
+        if name.startswith('head'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        from timm.models.layers import lecun_normal_
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
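These module-level initializers are driven by named_apply in the constructor hunk earlier, which visits every submodule together with its dotted name. A hedged approximation of that traversal using plain named_modules() on a toy model (apply_named below is a stand-in, not timm's named_apply):

import math
import torch.nn as nn

def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
    # same logic as the function added in the hunk above
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        if name.startswith('head.fc'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        else:
            nn.init.trunc_normal_(module.weight, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.constant_(module.bias, init_bias)

def apply_named(fn, model: nn.Module):
    # hedged stand-in for timm's named_apply: call fn(module, dotted_name) on every submodule
    for name, module in model.named_modules():
        fn(module, name)

toy = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 10))
apply_named(lambda m, n: _init_weight_vit(m, n, head_bias=-math.log(10)), toy)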
@@ -901,8 +960,9 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),

-    "hiera_small_abswin_256.untrained": _cfg(
-        #hf_hub_id='timm/',
+    "hiera_small_abswin_256.sbb2_ep200_in12k": _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
     "hiera_base_abswin_256.untrained": _cfg(
@@ -985,11 +1045,17 @@ def hiera_huge_224(pretrained=False, **kwargs):

 @register_model
 def hiera_small_abswin_256(pretrained=False, **kwargs):
-    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16))
+    model_args = dict(
+        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        init_values=1e-5, weight_init='jax', use_expand_proj=False,
+    )
     return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def hiera_base_abswin_256(pretrained=False, **kwargs):
-    model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16))
-    return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        init_values=1e-5, weight_init='jax',
+    )
+    return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
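With the entrypoints above, the updated Hiera variants are built the usual way through timm; the new hiera_small_abswin_256.sbb2_ep200_in12k tag matches the cfg added earlier (pulling the weights assumes the referenced timm/ hub artifact is published). The remaining two hunks add the matching sbb2_ep200_in12k entries to the ViT default_cfgs in vision_transformer.py. A hedged usage sketch, not part of the commit:

import timm
import torch

# Random-init model with the new defaults baked into the entrypoint
# (init_values=1e-5, weight_init='jax', use_expand_proj=False for the small abswin variant).
model = timm.create_model('hiera_small_abswin_256', pretrained=False)

# ImageNet-12k pretrained variant added by this commit (11821 classes, 256x256 input);
# pretrained=True assumes the timm/ hub weights referenced by the cfg are available.
# model = timm.create_model('hiera_small_abswin_256.sbb2_ep200_in12k', pretrained=True)

x = torch.randn(1, 3, 256, 256)
print(model(x).shape)  # classifier logits, or pooled features if the resolved cfg sets num_classes=0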
@@ -1967,6 +1967,10 @@ default_cfgs = {
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
@@ -1980,6 +1984,10 @@ default_cfgs = {
     'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
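And the two sbb ViTs that received the improved-reg4-init ImageNet-12k weights; a hedged usage sketch (weight download assumes the timm/ hub artifacts exist):

import timm

# The new 11821-class in12k heads are meant for fine-tuning or feature use;
# pass num_classes=0 to get pooled features instead of the 12k-class logits.
names = [
    'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k',
    'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k',
]
for name in names:
    model = timm.create_model(name, pretrained=False)  # set pretrained=True to pull hub weights
    print(name, model.num_classes)                     # 11821 per the cfgs added above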