Fix hiera init with num_classes=0, fix weight tag names for sbb2 hiera/vit weights, add LayerScale/LayerScale2d to layers
parent fee91fdd41
commit 2f3fed43b8
timm
timm/layers/__init__.py

@@ -29,6 +29,7 @@ from .grid import ndgrid, meshgrid
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
 from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
 from .inplace_abn import InplaceAbn
+from .layer_scale import LayerScale, LayerScale2d
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
timm/layers/layer_scale.py (new file)

@@ -0,0 +1,38 @@
+import torch
+from torch import nn
+
+
+class LayerScale(nn.Module):
+    """ LayerScale on tensors with channels in last-dim.
+    """
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class LayerScale2d(nn.Module):
+    """ LayerScale for tensors with torch 2D NCHW layout.
+    """
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        gamma = self.gamma.view(1, -1, 1, 1)
+        return x.mul_(gamma) if self.inplace else x * gamma
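As a quick illustration (not part of the diff), the two new modules can be applied to token and feature-map tensors once they are re-exported from timm.layers as in the __init__.py hunk above; the shapes below are arbitrary example values.

import torch
from timm.layers import LayerScale, LayerScale2d  # assumes a timm build that includes this commit

tokens = torch.randn(2, 196, 384)       # [batch, tokens, channels], channels-last
feat_map = torch.randn(2, 384, 14, 14)  # [batch, channels, H, W]

ls = LayerScale(384)        # gamma broadcasts over the last dim
ls2d = LayerScale2d(384)    # gamma is viewed as (1, C, 1, 1) before multiplying

print(ls(tokens).shape)      # torch.Size([2, 196, 384])
print(ls2d(feat_map).shape)  # torch.Size([2, 384, 14, 14])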
timm/models/hiera.py

@@ -31,10 +31,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple
-
+from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple

 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
@@ -289,7 +287,6 @@ class MaskUnitAttention(nn.Module):
         """ Input should be of shape [batch, tokens, channels]. """
         B, N, _ = x.shape
         num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
-
         qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
         q, k, v = qkv.unbind(0)
@@ -310,21 +307,6 @@ class MaskUnitAttention(nn.Module):
         return x


-class LayerScale(nn.Module):
-    def __init__(
-            self,
-            dim: int,
-            init_values: float = 1e-5,
-            inplace: bool = False,
-    ) -> None:
-        super().__init__()
-        self.inplace = inplace
-        self.gamma = nn.Parameter(init_values * torch.ones(dim))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.mul_(self.gamma) if self.inplace else x * self.gamma
-
-
 class HieraBlock(nn.Module):
     def __init__(
             self,
@@ -342,7 +324,6 @@ class HieraBlock(nn.Module):
             use_mask_unit_attn: bool = False,
     ):
         super().__init__()
-
         self.dim = dim
         self.dim_out = dim_out
@@ -631,10 +612,8 @@ class Hiera(nn.Module):
             nn.init.trunc_normal_(self.pos_embed_win, std=0.02)

         if weight_init != 'skip':
-            if weight_init == 'jax':
-                named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self)
-            else:
-                named_apply(_init_weight_vit, self)
+            init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit
+            named_apply(init_fn, self)
         if fix_init:
             self.fix_init_weight()
         if isinstance(self.head.fc, nn.Linear):
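A minimal sketch of why the removed branch broke feature-extraction models: the JAX-style head bias was computed as -math.log(self.num_classes), which raises a ValueError when num_classes=0 (as in the hiera *.mae configs); the new code only selects the init function and no longer derives a bias from num_classes.

import math

for num_classes in (1000, 0):
    try:
        print(num_classes, -math.log(num_classes))   # old head_bias computation
    except ValueError as e:
        print(num_classes, f"ValueError: {e}")       # num_classes=0 -> math domain error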
@@ -868,11 +847,13 @@ def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
         nn.init.trunc_normal_(module.weight, std=0.02)
         if isinstance(module, nn.Linear) and module.bias is not None:
             nn.init.constant_(module.bias, init_bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()


 def _init_weight_jax(module, name, head_bias=0.):
     if isinstance(module, nn.Linear):
-        if name.startswith('head'):
+        if name.startswith('head.fc'):
             nn.init.zeros_(module.weight)
             nn.init.constant_(module.bias, head_bias)
         else:
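The 'head' -> 'head.fc' prefix change matters because the init functions receive dotted submodule names from timm's named_apply helper; the sketch below is a simplified approximation of that traversal (not the library's exact signature), using a hypothetical stand-in for Hiera's classifier head.

import torch.nn as nn

def named_apply_sketch(fn, module, name=''):
    # visit each submodule with its dotted name, e.g. '', 'head', 'head.fc'
    fn(module=module, name=name)
    for child_name, child in module.named_children():
        named_apply_sketch(fn, child, name=f'{name}.{child_name}' if name else child_name)

root = nn.Module()
head = nn.Sequential()
head.add_module('norm', nn.LayerNorm(8))
head.add_module('fc', nn.Linear(8, 4))
root.add_module('head', head)

named_apply_sketch(lambda module, name: print(repr(name), type(module).__name__), root)
# name.startswith('head.fc') zero-inits only the final classifier Linear,
# whereas startswith('head') would also match any other nn.Linear living under the head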
@@ -960,7 +941,7 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),

-    "hiera_small_abswin_256.sbb2_ep200_in12k": _cfg(
+    "hiera_small_abswin_256.sbb2_e200_in12k": _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95,
@@ -1007,6 +988,7 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
         **kwargs,
     )

+
 @register_model
 def hiera_tiny_224(pretrained=False, **kwargs):
     model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
timm/models/vision_transformer.py

@@ -1967,7 +1967,7 @@ default_cfgs = {
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+    'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
@@ -1984,7 +1984,7 @@ default_cfgs = {
     'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
+    'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
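Assuming the corrected sbb2_e200 tags match what is actually published on the HF hub, the pretrained weights would then load under the renamed configs, e.g.:

import timm

# 'e200' (200-epoch) replaces the earlier 'ep200' spelling in the weight tag
model = timm.create_model('vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k', pretrained=True)
model = timm.create_model('hiera_small_abswin_256.sbb2_e200_in12k', pretrained=True)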