Add layer scale to hieradet

Ross Wightman 2024-08-21 11:23:39 -07:00
parent 47e6958263
commit 17923a66bb
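
The change threads an optional init_values argument through MultiScaleBlock and HieraDet and, when it is set, wraps both residual branches of each block in timm.layers.LayerScale. Layer scale is just a learnable per-channel gain on the residual branch, initialized to a small constant; below is a minimal sketch that is roughly equivalent to the timm layer (not the actual source):

    import torch
    import torch.nn as nn

    class LayerScaleSketch(nn.Module):
        """Learnable per-channel scaling of a residual branch (CaiT-style layer scale)."""
        def __init__(self, dim: int, init_values: float = 1e-5):
            super().__init__()
            # gamma starts at init_values, so each block initially contributes little
            # to the residual stream, which tends to stabilize deeper networks
            self.gamma = nn.Parameter(init_values * torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.gamma  # broadcasts over the trailing channel dimension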


@@ -1,4 +1,5 @@
 import math
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, Union

@@ -8,7 +9,7 @@ import torch.nn.functional as F
 from torch.jit import Final

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
+from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
     get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn

 from ._builder import build_model_with_cfg
@@ -121,11 +122,12 @@ class MultiScaleBlock(nn.Module):
             dim_out: int,
             num_heads: int,
             mlp_ratio: float = 4.0,
-            drop_path: float = 0.0,
             q_stride: Optional[Tuple[int, int]] = None,
             norm_layer: Union[nn.Module, str] = "LayerNorm",
             act_layer: Union[nn.Module, str] = "GELU",
             window_size: int = 0,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.0,
     ):
         super().__init__()
         norm_layer = get_norm_layer(norm_layer)
@@ -135,30 +137,6 @@ class MultiScaleBlock(nn.Module):
         self.dim = dim
         self.dim_out = dim_out
         self.q_stride = q_stride
-        if self.q_stride:
-            q_pool = nn.MaxPool2d(
-                kernel_size=q_stride,
-                stride=q_stride,
-                ceil_mode=False,
-            )
-        else:
-            q_pool = None
-
-        self.norm1 = norm_layer(dim)
-        self.attn = MultiScaleAttention(
-            dim,
-            dim_out,
-            num_heads=num_heads,
-            q_pool=q_pool,
-        )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.norm2 = norm_layer(dim_out)
-        self.mlp = Mlp(
-            dim_out,
-            int(dim_out * mlp_ratio),
-            act_layer=act_layer,
-        )
-
         if dim != dim_out:
             self.proj = nn.Linear(dim, dim_out)
@@ -173,6 +151,25 @@ class MultiScaleBlock(nn.Module):
                 ceil_mode=False,
             )

+        self.norm1 = norm_layer(dim)
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=deepcopy(self.pool),
+        )
+        self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = Mlp(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            act_layer=act_layer,
+        )
+        self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x  # B, H, W, C
         x = self.norm1(x)
@@ -206,9 +203,8 @@ class MultiScaleBlock(nn.Module):
             x = window_unpartition(x, window_size, (Hp, Wp))
             x = x[:, :H, :W, :].contiguous()  # unpad

-        x = shortcut + self.drop_path(x)
-
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = shortcut + self.drop_path1(self.ls1(x))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
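
With these edits the forward path keeps the usual pre-norm residual pattern, with layer scale applied between each sub-module and its drop path on both the attention and MLP branches. A self-contained sketch of just the MLP residual branch using the same timm layers (the tensor shape and hyper-parameters here are made up for illustration):

    import torch
    import torch.nn as nn
    from timm.layers import DropPath, LayerScale, Mlp

    dim = 96
    norm2 = nn.LayerNorm(dim)
    mlp = Mlp(dim, int(dim * 4.0), act_layer=nn.GELU)
    ls2 = LayerScale(dim, init_values=1e-5)
    drop_path2 = DropPath(0.1)

    x = torch.randn(2, 14, 14, dim)          # B, H, W, C layout used by HieraDet
    x = x + drop_path2(ls2(mlp(norm2(x))))   # norm -> mlp -> layer scale -> drop path -> residual add
    print(x.shape)                           # shape is unchanged: torch.Size([2, 14, 14, 96])

When init_values is None (the default), ls1 and ls2 are nn.Identity(), so configurations that do not opt in behave exactly as before.
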
@@ -280,6 +276,7 @@ class HieraDet(nn.Module):
                 16,
                 20,
             ),
+            init_values: Optional[float] = None,
             weight_init: str = '',
             fix_init: bool = True,
             head_init_scale: float = 0.001,
@@ -628,7 +625,7 @@ def sam2_hiera_large(pretrained=False, **kwargs):

 @register_model
 def hieradet_small(pretrained=False, **kwargs):
-    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8))
+    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
     return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
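
With this in place, hieradet_small builds with layer scale enabled by default (init_values=1e-5), and the value can still be overridden or disabled per model via create_model kwargs. A quick check, assuming a timm install that includes this commit; the attribute path model.blocks[...] follows timm's usual naming and may need adjusting if the module tree differs:

    import torch
    import timm

    # hieradet_small now defaults to init_values=1e-5, so blocks carry LayerScale modules
    model = timm.create_model('hieradet_small', pretrained=False)
    print(type(model.blocks[0].ls1).__name__, model.blocks[0].ls1.gamma.shape)  # e.g. LayerScale torch.Size([96])

    # passing init_values=None falls back to nn.Identity() on both residual branches
    plain = timm.create_model('hieradet_small', pretrained=False, init_values=None)
    print(type(plain.blocks[0].ls1).__name__)  # Identity

    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))
    print(out.shape)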