Add layer scale to hieradet

sbb2_vit_hiera_weights
Ross Wightman 2024-08-21 11:23:39 -07:00
parent 47e6958263
commit 17923a66bb
1 changed file with 27 additions and 30 deletions
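For context: LayerScale (introduced in CaiT) multiplies each residual branch by a learnable per-channel vector initialized to a small constant, so deep pre-norm transformer blocks start out close to identity. A minimal sketch of such a module, consistent with how timm.layers.LayerScale is used in the diff below (the real timm class may differ in details such as an inplace option):

import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (CaiT-style); sketch only."""

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        # gamma starts near zero, so each block initially contributes little,
        # which tends to stabilize training as depth grows
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma

In the block changes below, ls1 and ls2 wrap the attention and MLP outputs respectively, and init_values=None leaves them as nn.Identity() so existing configs behave as before.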
timm/models

@@ -1,4 +1,5 @@
 import math
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -8,7 +9,7 @@ import torch.nn.functional as F
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
+from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
     get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
 
 from ._builder import build_model_with_cfg
@@ -121,11 +122,12 @@ class MultiScaleBlock(nn.Module):
             dim_out: int,
             num_heads: int,
             mlp_ratio: float = 4.0,
-            drop_path: float = 0.0,
             q_stride: Optional[Tuple[int, int]] = None,
             norm_layer: Union[nn.Module, str] = "LayerNorm",
             act_layer: Union[nn.Module, str] = "GELU",
             window_size: int = 0,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.0,
     ):
         super().__init__()
         norm_layer = get_norm_layer(norm_layer)
@@ -135,30 +137,6 @@
         self.dim = dim
         self.dim_out = dim_out
         self.q_stride = q_stride
-        if self.q_stride:
-            q_pool = nn.MaxPool2d(
-                kernel_size=q_stride,
-                stride=q_stride,
-                ceil_mode=False,
-            )
-        else:
-            q_pool = None
-
-        self.norm1 = norm_layer(dim)
-        self.attn = MultiScaleAttention(
-            dim,
-            dim_out,
-            num_heads=num_heads,
-            q_pool=q_pool,
-        )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.norm2 = norm_layer(dim_out)
-        self.mlp = Mlp(
-            dim_out,
-            int(dim_out * mlp_ratio),
-            act_layer=act_layer,
-        )
 
         if dim != dim_out:
             self.proj = nn.Linear(dim, dim_out)
@@ -173,6 +151,25 @@
                 ceil_mode=False,
             )
+
+        self.norm1 = norm_layer(dim)
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=deepcopy(self.pool),
+        )
+        self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = Mlp(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            act_layer=act_layer,
+        )
+        self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x  # B, H, W, C
         x = self.norm1(x)
@@ -206,9 +203,8 @@ class MultiScaleBlock(nn.Module):
             x = window_unpartition(x, window_size, (Hp, Wp))
             x = x[:, :H, :W, :].contiguous()  # unpad
 
-        x = shortcut + self.drop_path(x)
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = shortcut + self.drop_path1(self.ls1(x))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
@@ -280,6 +276,7 @@ class HieraDet(nn.Module):
                 16,
                 20,
             ),
+            init_values: Optional[float] = None,
            weight_init: str = '',
            fix_init: bool = True,
            head_init_scale: float = 0.001,
@@ -628,7 +625,7 @@ def sam2_hiera_large(pretrained=False, **kwargs):
 
 @register_model
 def hieradet_small(pretrained=False, **kwargs):
-    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8))
+    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
     return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
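As a quick sanity check of the registration change above, one could count the new layer scale parameters; a hypothetical snippet, assuming a timm build that includes this commit and the ls1/ls2 attribute names from the diff:

import timm

# hieradet_small now defaults to init_values=1e-5, so each block should carry
# ls1/ls2 LayerScale modules with a learnable per-channel gamma
model = timm.create_model('hieradet_small', pretrained=False)
gammas = [n for n, p in model.named_parameters() if n.endswith(('ls1.gamma', 'ls2.gamma'))]
print(len(gammas))  # expect two per block; init_values=None would drop them back to nn.Identity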