Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add layer scale to hieradet
commit 17923a66bb (parent 47e6958263)
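For readers unfamiliar with the technique: layer scale multiplies each residual branch by a learnable per-channel vector initialized to a small constant (init_values), so every block starts out close to an identity mapping. A minimal sketch of the idea, written to match the behavior of the timm.layers.LayerScale import used in the diff below (timm's actual class also supports an inplace option):

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # Scales the residual branch by a learnable per-channel gamma.
    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        # Small init keeps each block near-identity early in training.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma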
@@ -1,4 +1,5 @@
 import math
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
@@ -8,7 +9,7 @@ import torch.nn.functional as F
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
+from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
     get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
 
 from ._builder import build_model_with_cfg
@@ -121,11 +122,12 @@ class MultiScaleBlock(nn.Module):
             dim_out: int,
             num_heads: int,
             mlp_ratio: float = 4.0,
-            drop_path: float = 0.0,
             q_stride: Optional[Tuple[int, int]] = None,
             norm_layer: Union[nn.Module, str] = "LayerNorm",
             act_layer: Union[nn.Module, str] = "GELU",
             window_size: int = 0,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.0,
     ):
         super().__init__()
         norm_layer = get_norm_layer(norm_layer)
@@ -135,30 +137,6 @@ class MultiScaleBlock(nn.Module):
         self.dim = dim
         self.dim_out = dim_out
         self.q_stride = q_stride
-        if self.q_stride:
-            q_pool = nn.MaxPool2d(
-                kernel_size=q_stride,
-                stride=q_stride,
-                ceil_mode=False,
-            )
-        else:
-            q_pool = None
-
-        self.norm1 = norm_layer(dim)
-        self.attn = MultiScaleAttention(
-            dim,
-            dim_out,
-            num_heads=num_heads,
-            q_pool=q_pool,
-        )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.norm2 = norm_layer(dim_out)
-        self.mlp = Mlp(
-            dim_out,
-            int(dim_out * mlp_ratio),
-            act_layer=act_layer,
-        )
 
         if dim != dim_out:
             self.proj = nn.Linear(dim, dim_out)
@@ -173,6 +151,25 @@ class MultiScaleBlock(nn.Module):
                 ceil_mode=False,
             )
+        self.norm1 = norm_layer(dim)
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=deepcopy(self.pool),
+        )
+        self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = Mlp(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            act_layer=act_layer,
+        )
+        self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x  # B, H, W, C
         x = self.norm1(x)
@@ -206,9 +203,8 @@ class MultiScaleBlock(nn.Module):
             x = window_unpartition(x, window_size, (Hp, Wp))
         x = x[:, :H, :W, :].contiguous()  # unpad
 
-        x = shortcut + self.drop_path(x)
-
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = shortcut + self.drop_path1(self.ls1(x))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
 
 
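The residual rewiring above follows the standard pre-norm transformer pattern, now with an independent LayerScale and DropPath per branch. A condensed sketch of the resulting forward pass (window partitioning, padding, and q_stride pooling elided; ls1/ls2 and drop_path1/drop_path2 are the modules created in __init__ above, falling back to nn.Identity when disabled):

def forward(self, x):
    shortcut = x                                 # B, H, W, C
    x = self.attn(self.norm1(x))                 # attention branch
    x = shortcut + self.drop_path1(self.ls1(x))  # scaled residual 1
    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))  # scaled residual 2
    return x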
@@ -280,6 +276,7 @@ class HieraDet(nn.Module):
                 16,
                 20,
             ),
+            init_values: Optional[float] = None,
             weight_init: str = '',
             fix_init: bool = True,
             head_init_scale: float = 0.001,
@@ -628,7 +625,7 @@ def sam2_hiera_large(pretrained=False, **kwargs):
 
 @register_model
 def hieradet_small(pretrained=False, **kwargs):
-    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8))
+    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
     return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
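Since user kwargs are merged over model_args via **dict(model_args, **kwargs), the new layer-scale default can be overridden per call. A usage sketch, assuming a timm build that includes this commit:

import timm

# hieradet_small now enables layer scale by default (init_values=1e-5)
model = timm.create_model('hieradet_small', pretrained=False)

# Passing init_values=None overrides the default and disables layer scale
# (ls1/ls2 are constructed as nn.Identity).
model_no_ls = timm.create_model('hieradet_small', pretrained=False, init_values=None)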