Fix hiera init with num_classes=0, fix weight tag names for sbb2 hiera/vit weights, add LayerScale/LayerScale2d to layers

sbb2_vit_hiera_weights
Ross Wightman 2024-08-15 11:14:38 -07:00
parent fee91fdd41
commit 2f3fed43b8
4 changed files with 49 additions and 28 deletions

timm/layers/__init__.py

@@ -29,6 +29,7 @@ from .grid import ndgrid, meshgrid
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
from .inplace_abn import InplaceAbn
from .layer_scale import LayerScale, LayerScale2d
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp

timm/layers/layer_scale.py

@@ -0,0 +1,38 @@
import torch
from torch import nn


class LayerScale(nn.Module):
    """ LayerScale on tensors with channels in last-dim.
    """
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class LayerScale2d(nn.Module):
    """ LayerScale for tensors with torch 2D NCHW layout.
    """
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        gamma = self.gamma.view(1, -1, 1, 1)
        return x.mul_(gamma) if self.inplace else x * gamma
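
Not part of the diff, but a minimal usage sketch of the two modules added above, assuming this timm revision is installed: LayerScale applies a learnable per-channel gain to channel-last token tensors (the ViT convention), while LayerScale2d does the same for NCHW feature maps.

import torch
from timm.layers import LayerScale, LayerScale2d  # exported via the __init__ change above

# Channel-last tokens, e.g. ViT block output: (batch, tokens, channels)
tokens = torch.randn(2, 196, 384)
ls = LayerScale(dim=384, init_values=1e-5)
print(ls(tokens).shape)   # torch.Size([2, 196, 384])

# NCHW feature map, e.g. a conv/hybrid stage: (batch, channels, height, width)
fmap = torch.randn(2, 384, 14, 14)
ls2d = LayerScale2d(dim=384, init_values=1e-5)
print(ls2d(fmap).shape)   # torch.Size([2, 384, 14, 14])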

timm/models/hiera.py

@@ -31,10 +31,8 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple
+from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple
from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg
@@ -289,7 +287,6 @@ class MaskUnitAttention(nn.Module):
""" Input should be of shape [batch, tokens, channels]. """
B, N, _ = x.shape
num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
q, k, v = qkv.unbind(0)
@@ -310,21 +307,6 @@ class MaskUnitAttention(nn.Module):
return x
-class LayerScale(nn.Module):
-    def __init__(
-            self,
-            dim: int,
-            init_values: float = 1e-5,
-            inplace: bool = False,
-    ) -> None:
-        super().__init__()
-        self.inplace = inplace
-        self.gamma = nn.Parameter(init_values * torch.ones(dim))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.mul_(self.gamma) if self.inplace else x * self.gamma
class HieraBlock(nn.Module):
def __init__(
self,
@@ -342,7 +324,6 @@ class HieraBlock(nn.Module):
use_mask_unit_attn: bool = False,
):
super().__init__()
self.dim = dim
self.dim_out = dim_out
@@ -631,10 +612,8 @@ class Hiera(nn.Module):
nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
if weight_init != 'skip':
-            if weight_init == 'jax':
-                named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self)
-            else:
-                named_apply(_init_weight_vit, self)
+            init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit
+            named_apply(init_fn, self)
if fix_init:
self.fix_init_weight()
if isinstance(self.head.fc, nn.Linear):
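
A note on the init refactor above (my reading of the change, not commit text): the removed 'jax' branch computed head_bias=-math.log(self.num_classes) eagerly, which raises ValueError for feature-extraction models built with num_classes=0; the single init_fn path avoids binding that value during __init__. A small illustrative sketch, with a hypothetical guarded helper:

import math

def jax_head_bias(num_classes: int) -> float:
    # Mirrors the removed eager computation: math.log(0) raises
    # ValueError, so creating a Hiera with num_classes=0 failed at init.
    return -math.log(num_classes)

def guarded_head_bias(num_classes: int) -> float:
    # Hypothetical guard showing the intent of the fix.
    return -math.log(num_classes) if num_classes > 0 else 0.0

print(guarded_head_bias(1000))  # about -6.91
print(guarded_head_bias(0))     # 0.0
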
@@ -868,11 +847,13 @@ def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
nn.init.trunc_normal_(module.weight, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.constant_(module.bias, init_bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def _init_weight_jax(module, name, head_bias=0.):
if isinstance(module, nn.Linear):
-        if name.startswith('head'):
+        if name.startswith('head.fc'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
@@ -960,7 +941,7 @@ default_cfgs = generate_default_cfgs({
num_classes=0,
),
"hiera_small_abswin_256.sbb2_ep200_in12k": _cfg(
"hiera_small_abswin_256.sbb2_e200_in12k": _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95,
@@ -1007,6 +988,7 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
**kwargs,
)
@register_model
def hiera_tiny_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))

timm/models/vision_transformer.py

@@ -1967,7 +1967,7 @@ default_cfgs = {
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+    'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
@@ -1984,7 +1984,7 @@ default_cfgs = {
'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+    'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
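
For reference (not part of the diff), a sketch of how the corrected pretrained tags would be referenced; this assumes the sbb2 weights are published on the HF hub under these names.

import timm

# Hiera abswin small, ImageNet-12k pretrain, corrected 'e200' tag
hiera = timm.create_model('hiera_small_abswin_256.sbb2_e200_in12k', pretrained=True)

# Matching sbb2 ViT weights with the same corrected tag
vit = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k', pretrained=True)

print(hiera.num_classes, vit.num_classes)  # both cfgs above set num_classes=11821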