Add layer scale to hieradet

parent 47e6958263
commit 17923a66bb

timm/models
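Context for the diff below: LayerScale multiplies each residual branch by a learnable per-channel gamma vector that starts at a small constant (init_values), a common trick for stabilizing deeper ViT-style blocks. A minimal sketch of the behaviour being wired in, assuming the timm.layers.LayerScale(dim, init_values) call pattern used in this change (details such as an inplace option are not shown in the diff and are assumptions here):

import torch
import torch.nn as nn

class LayerScaleSketch(nn.Module):
    # Stand-in for timm.layers.LayerScale: scales the channel (last) dim by a
    # learnable per-channel gamma initialized to init_values.
    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma  # broadcasts over the leading (B, H, W) dims

With init_values=None the blocks below fall back to nn.Identity(), so existing configs behave as before unless they opt in.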
@@ -1,4 +1,5 @@
 import math
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
@@ -8,7 +9,7 @@ import torch.nn.functional as F
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
+from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
     get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
 
 from ._builder import build_model_with_cfg
@@ -121,11 +122,12 @@ class MultiScaleBlock(nn.Module):
             dim_out: int,
             num_heads: int,
             mlp_ratio: float = 4.0,
-            drop_path: float = 0.0,
             q_stride: Optional[Tuple[int, int]] = None,
             norm_layer: Union[nn.Module, str] = "LayerNorm",
             act_layer: Union[nn.Module, str] = "GELU",
             window_size: int = 0,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.0,
     ):
         super().__init__()
         norm_layer = get_norm_layer(norm_layer)
@@ -135,30 +137,6 @@ class MultiScaleBlock(nn.Module):
         self.dim = dim
         self.dim_out = dim_out
         self.q_stride = q_stride
-        if self.q_stride:
-            q_pool = nn.MaxPool2d(
-                kernel_size=q_stride,
-                stride=q_stride,
-                ceil_mode=False,
-            )
-        else:
-            q_pool = None
-
-        self.norm1 = norm_layer(dim)
-        self.attn = MultiScaleAttention(
-            dim,
-            dim_out,
-            num_heads=num_heads,
-            q_pool=q_pool,
-        )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.norm2 = norm_layer(dim_out)
-        self.mlp = Mlp(
-            dim_out,
-            int(dim_out * mlp_ratio),
-            act_layer=act_layer,
-        )
 
         if dim != dim_out:
             self.proj = nn.Linear(dim, dim_out)
@@ -173,6 +151,25 @@ class MultiScaleBlock(nn.Module):
                 ceil_mode=False,
             )
 
+        self.norm1 = norm_layer(dim)
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=deepcopy(self.pool),
+        )
+        self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = Mlp(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            act_layer=act_layer,
+        )
+        self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x  # B, H, W, C
         x = self.norm1(x)
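For reference, a hedged construction example of the updated block; the import path and the concrete sizes below are assumptions for illustration, and the leading dim argument comes from the part of the signature not shown in these hunks:

# Hypothetical usage of the updated constructor (values chosen for illustration).
import torch
from timm.models.hieradet_sam2 import MultiScaleBlock  # assumed module path

block = MultiScaleBlock(
    dim=96,
    dim_out=96,
    num_heads=1,
    window_size=8,
    init_values=1e-5,  # enables ls1/ls2; None keeps them as nn.Identity()
)
x = torch.randn(1, 16, 16, 96)  # B, H, W, C layout, per the forward() comment
y = block(x)                    # spatial size unchanged here since q_stride is None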
@@ -206,9 +203,8 @@ class MultiScaleBlock(nn.Module):
             x = window_unpartition(x, window_size, (Hp, Wp))
             x = x[:, :H, :W, :].contiguous()  # unpad
 
-        x = shortcut + self.drop_path(x)
-
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = shortcut + self.drop_path1(self.ls1(x))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
 
 
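Taken together with the constructor hunks above, the block now uses the pre-norm residual layout with LayerScale applied before stochastic depth on both branches. A simplified sketch of that wiring (window partitioning, padding, q-stride pooling, and the dim != dim_out projection are deliberately omitted; this is illustrative, not the actual forward()):

import torch.nn as nn

class SimplifiedResidualWiring(nn.Module):
    # Illustrative only: shows where ls1/ls2 and drop_path1/drop_path2 sit
    # relative to the two residual additions in the updated block.
    def __init__(self, norm1, attn, ls1, drop_path1, norm2, mlp, ls2, drop_path2):
        super().__init__()
        self.norm1, self.attn, self.ls1, self.drop_path1 = norm1, attn, ls1, drop_path1
        self.norm2, self.mlp, self.ls2, self.drop_path2 = norm2, mlp, ls2, drop_path2

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x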
@@ -280,6 +276,7 @@ class HieraDet(nn.Module):
                 16,
                 20,
             ),
+            init_values: Optional[float] = None,
             weight_init: str = '',
             fix_init: bool = True,
             head_init_scale: float = 0.001,
@@ -628,7 +625,7 @@ def sam2_hiera_large(pretrained=False, **kwargs):
 
 @register_model
 def hieradet_small(pretrained=False, **kwargs):
-    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8))
+    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
     return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
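Because model_args is merged with **kwargs in the registration function above, the new init_values default can also be overridden at creation time. A hedged usage sketch (whether pretrained weights exist for this tag is not established by the diff):

import timm

# Layer scale is on by default for this config (init_values=1e-5 per the diff above).
model = timm.create_model('hieradet_small', pretrained=False)

# Override at creation time; relies on HieraDet accepting init_values=None as
# shown in its updated signature.
model_no_ls = timm.create_model('hieradet_small', pretrained=False, init_values=None)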