diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index 200dd3e0..69af2f48 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -40,6 +40,7 @@ from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
+from ._manipulate import named_apply

 __all__ = ['Hiera']

@@ -309,6 +310,21 @@ class MaskUnitAttention(nn.Module):
         return x


+class LayerScale(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class HieraBlock(nn.Module):
     def __init__(
             self,
@@ -317,6 +333,7 @@ class HieraBlock(nn.Module):
             heads: int,
             mlp_ratio: float = 4.0,
             drop_path: float = 0.0,
+            init_values: Optional[float] = None,
             norm_layer: nn.Module = nn.LayerNorm,
             act_layer: nn.Module = nn.GELU,
             q_stride: int = 1,
@@ -348,13 +365,14 @@ class HieraBlock(nn.Module):
             window_size,
             use_mask_unit_attn
         )
+        self.ls1 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

         self.norm2 = norm_layer(dim_out)
         self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
+        self.ls2 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
         x_norm = self.norm1(x)
@@ -369,10 +387,10 @@ class HieraBlock(nn.Module):
             ],
             dim=-1,
         )
-        x = x + self.drop_path1(self.attn(x_norm))
+        x = x + self.drop_path1(self.ls1(self.attn(x_norm)))

         # MLP
-        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x


@@ -470,6 +488,7 @@ class Hiera(nn.Module):
             mask_unit_size: Tuple[int, ...] = (8, 8),  # must divide q_stride ** (#stages-1)
             # mask_unit_attn: which stages use mask unit attention?
             mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
+            use_expand_proj: bool = True,
             dim_mul: float = 2.0,
             head_mul: float = 2.0,
             patch_kernel: Tuple[int, ...] = (7, 7),
@@ -477,6 +496,9 @@ class Hiera(nn.Module):
             patch_padding: Tuple[int, ...] = (3, 3),
             mlp_ratio: float = 4.0,
             drop_path_rate: float = 0.0,
+            init_values: Optional[float] = None,
+            fix_init: bool = True,
+            weight_init: str = '',
             norm_layer: Union[str, nn.Module] = "LayerNorm",
             drop_rate: float = 0.0,
             patch_drop_rate: float = 0.0,
@@ -575,9 +597,11 @@ class Hiera(nn.Module):
                 heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 drop_path=dpr[i],
+                init_values=init_values,
                 norm_layer=norm_layer,
                 q_stride=(flat_q_stride if i in q_pool_blocks else 1),
                 window_size=flat_mu_size,
+                use_expand_proj=use_expand_proj,
                 use_mask_unit_attn=use_mask_unit_attn,
             )
             embed_dim = dim_out
@@ -605,19 +629,25 @@ class Hiera(nn.Module):
         elif self.pos_embed_abs is not None:
             nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
             nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
-        self.apply(partial(self._init_weights))
+
+        if weight_init != 'skip':
+            if weight_init == 'jax':
+                named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self)
+            else:
+                named_apply(_init_weight_vit, self)
+        if fix_init:
+            self.fix_init_weight()
         if isinstance(self.head.fc, nn.Linear):
             self.head.fc.weight.data.mul_(head_init_scale)
             self.head.fc.bias.data.mul_(head_init_scale)

-    def _init_weights(self, m, init_bias=0.02):
-        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
-            nn.init.trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, init_bias)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, init_bias)
-            nn.init.constant_(m.weight, 1.0)
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

     @torch.jit.ignore
     def no_weight_decay(self):
@@ -829,6 +859,35 @@ class Hiera(nn.Module):
         return x


+def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
+    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        if name.startswith('head.fc'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                nn.init.constant_(module.bias, init_bias)
+
+
+def _init_weight_jax(module, name, head_bias=0.):
+    if isinstance(module, nn.Linear):
+        if name.startswith('head'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        from timm.models.layers import lecun_normal_
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -901,8 +960,9 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),

-    "hiera_small_abswin_256.untrained": _cfg(
-        #hf_hub_id='timm/',
+    "hiera_small_abswin_256.sbb2_ep200_in12k": _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
     "hiera_base_abswin_256.untrained": _cfg(
@@ -985,11 +1045,17 @@ def hiera_huge_224(pretrained=False, **kwargs):

 @register_model
 def hiera_small_abswin_256(pretrained=False, **kwargs):
-    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16))
+    model_args = dict(
+        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        init_values=1e-5, weight_init='jax', use_expand_proj=False,
+    )
     return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def hiera_base_abswin_256(pretrained=False, **kwargs):
-    model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16))
-    return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
\ No newline at end of file
+    model_args = dict(
+        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16),
+        init_values=1e-5, weight_init='jax',
+    )
+    return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index b613683f..64cab9ee 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -1967,6 +1967,10 @@ default_cfgs = {
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
@@ -1980,6 +1984,10 @@ default_cfgs = {
     'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
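
Reviewer note, not part of the patch: a minimal sketch of how the options added above could be exercised once the full change is applied (this excerpt does not show the HieraBlock plumbing for use_expand_proj, which is assumed present). The num_classes=1000 and img_size=256 values are illustrative assumptions; timm.create_model forwards extra kwargs to the hiera_small_abswin_256 entrypoint, which merges them into model_args.

    import torch
    import timm

    # Sketch only: init_values enables the per-block LayerScale added above,
    # weight_init='jax' routes init through _init_weight_jax via named_apply,
    # and fix_init=True divides each block's attn.proj / mlp.fc2 weights by
    # sqrt(2 * layer_id) (1-based), as in fix_init_weight().
    model = timm.create_model(
        'hiera_small_abswin_256',
        pretrained=False,
        num_classes=1000,   # assumed head size for this example
        img_size=256,       # matches the 256x256 default cfg added above
        init_values=1e-5,
        weight_init='jax',
        fix_init=True,
    )
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 256, 256))
    print(logits.shape)  # expected: torch.Size([1, 1000])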