mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2258 from huggingface/sbb2_vit_hiera_weights
Update Hiera model for abswin & add more in12k weights for hiera & vit
This commit is contained in:
commit
00c5be7656
@ -52,14 +52,14 @@ FEAT_INTER_FILTERS = [
|
|||||||
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
|
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
|
||||||
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
||||||
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
||||||
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
|
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2'
|
||||||
]
|
]
|
||||||
|
|
||||||
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
||||||
NON_STD_FILTERS = [
|
NON_STD_FILTERS = [
|
||||||
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
|
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
|
||||||
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
|
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
|
||||||
]
|
]
|
||||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||||
|
@ -5,7 +5,7 @@ from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttention
|
|||||||
from .attention_pool import AttentionPoolLatent
|
from .attention_pool import AttentionPoolLatent
|
||||||
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
|
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
|
||||||
from .blur_pool import BlurPool2d, create_aa
|
from .blur_pool import BlurPool2d, create_aa
|
||||||
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
|
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
|
||||||
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||||
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
|
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
|
||||||
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
|
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
|
||||||
@ -29,6 +29,7 @@ from .grid import ndgrid, meshgrid
|
|||||||
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
|
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
|
||||||
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
|
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
|
||||||
from .inplace_abn import InplaceAbn
|
from .inplace_abn import InplaceAbn
|
||||||
|
from .layer_scale import LayerScale, LayerScale2d
|
||||||
from .linear import Linear
|
from .linear import Linear
|
||||||
from .mixed_conv2d import MixedConv2d
|
from .mixed_conv2d import MixedConv2d
|
||||||
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
|
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
|
||||||
@ -56,4 +57,5 @@ from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2d
|
|||||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||||
from .trace_utils import _assert, _float_to_int
|
from .trace_utils import _assert, _float_to_int
|
||||||
from .typing import LayerType, PadType
|
from .typing import LayerType, PadType
|
||||||
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
|
||||||
|
init_weight_jax, init_weight_vit
|
||||||
|
@ -134,7 +134,8 @@ class ClassifierHead(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class NormMlpClassifierHead(nn.Module):
|
class NormMlpClassifierHead(nn.Module):
|
||||||
|
""" A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
@ -204,3 +205,79 @@ class NormMlpClassifierHead(nn.Module):
|
|||||||
return x
|
return x
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ClNormMlpClassifierHead(nn.Module):
    """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors

    Pooling reduces over dim 1 for 'NLC' input and over dims (1, 2) for 'NHWC' input,
    then norm, optional pre-logits MLP, dropout and the final classifier are applied.
    """
    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm',
            act_layer: Union[str, Callable] = 'gelu',
            input_fmt: str = 'NHWC',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
            input_fmt: Input tensor layout, one of 'NHWC' or 'NLC'. Selects the dims that are pooled.
        """
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        assert pool_type in ('', 'avg', 'max', 'avgmax')
        self.pool_type = pool_type
        assert input_fmt in ('NHWC', 'NLC')
        self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)

        self.norm = norm_layer(in_features)
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(in_features, hidden_size)),
                ('act', act_layer()),
            ]))
            # the classifier consumes the MLP output when a hidden layer is present
            self.num_features = hidden_size
        else:
            self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(drop_rate)
        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
        """ Replace the classifier layer, optionally changing pooling and dropping norm / pre-logits.

        Args:
            num_classes: Number of classes for the new classifier, classifier is Identity if <= 0.
            pool_type: New global pooling type, unchanged if None.
            reset_other: If True, replace the norm and pre-logits layers with Identity.
        """
        if pool_type is not None:
            # same set of pooling modes the constructor accepts
            assert pool_type in ('', 'avg', 'max', 'avgmax')
            self.pool_type = pool_type
        if reset_other:
            self.pre_logits = nn.Identity()
            self.norm = nn.Identity()
            # with the pre-logits MLP removed, features reaching fc are in_features wide again
            self.hidden_size = None
            self.num_features = self.in_features
        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def _global_pool(self, x):
        """ Reduce x over self.pool_dim according to self.pool_type, no-op if pooling is disabled. """
        if self.pool_type:
            if self.pool_type == 'avg':
                x = x.mean(dim=self.pool_dim)
            elif self.pool_type == 'max':
                x = x.amax(dim=self.pool_dim)
            elif self.pool_type == 'avgmax':
                x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
        return x

    def forward(self, x, pre_logits: bool = False):
        """ Apply pool -> norm -> pre-logits -> dropout, then the classifier unless pre_logits is True. """
        x = self._global_pool(x)
        x = self.norm(x)
        x = self.pre_logits(x)
        x = self.drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
|
||||||
|
@ -97,6 +97,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
|
|||||||
return None
|
return None
|
||||||
if isinstance(name, Callable):
|
if isinstance(name, Callable):
|
||||||
return name
|
return name
|
||||||
|
name = name.lower()
|
||||||
if not (is_exportable() or is_scriptable()):
|
if not (is_exportable() or is_scriptable()):
|
||||||
# If not exporting or scripting the model, first look for a memory-efficient version with
|
# If not exporting or scripting the model, first look for a memory-efficient version with
|
||||||
# custom autograd, then fallback
|
# custom autograd, then fallback
|
||||||
@ -117,6 +118,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
|
|||||||
return name
|
return name
|
||||||
if not name:
|
if not name:
|
||||||
return None
|
return None
|
||||||
|
name = name.lower()
|
||||||
if not (is_exportable() or is_scriptable()):
|
if not (is_exportable() or is_scriptable()):
|
||||||
if name in _ACT_LAYER_ME:
|
if name in _ACT_LAYER_ME:
|
||||||
return _ACT_LAYER_ME[name]
|
return _ACT_LAYER_ME[name]
|
||||||
|
38
timm/layers/layer_scale.py
Normal file
38
timm/layers/layer_scale.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class LayerScale(nn.Module):
    """ Learnable per-channel scale for tensors with channels in the last dim.

    Holds a parameter `gamma` of length `dim`, initialised to `init_values`, and multiplies
    the input by it. With `inplace=True` the input tensor itself is modified and returned.
    """
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        scale = torch.ones(dim) * init_values
        self.gamma = nn.Parameter(scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
||||||
|
|
||||||
|
|
||||||
|
class LayerScale2d(nn.Module):
    """ Learnable per-channel scale for tensors in torch 2D NCHW layout.

    Holds a parameter `gamma` of length `dim`, viewed as (1, dim, 1, 1) so it lines up with
    dim 1 of the input. With `inplace=True` the input tensor itself is modified and returned.
    """
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ):
        super().__init__()
        self.inplace = inplace
        scale = torch.ones(dim) * init_values
        self.gamma = nn.Parameter(scale)

    def forward(self, x):
        scale = self.gamma.view(1, -1, 1, 1)
        if self.inplace:
            return x.mul_(scale)
        return x * scale
|
||||||
|
|
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
|
from torch import nn
|
||||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||||
|
|
||||||
|
|
||||||
@ -123,3 +123,45 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
|
|||||||
|
|
||||||
def lecun_normal_(tensor):
|
def lecun_normal_(tensor):
|
||||||
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
|
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
|
||||||
|
|
||||||
|
|
||||||
|
def init_weight_vit(
        module: nn.Module,
        name: str,
        init_bias: float = 0.02,
        head_bias: float = 0.,
        classifier_name: str = 'head'
):
    """ ViT style weight init, intended for use as a `(module, name)` callback.

    Args:
        module: Module to initialize.
        name: Qualified name of the module, compared against `classifier_name`.
        init_bias: Constant used for the bias of non-classifier nn.Linear layers.
        head_bias: Constant used for the bias of the classifier layer.
        classifier_name: Name prefix that identifies the classifier layer.
    """
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        if name.startswith(classifier_name):
            nn.init.zeros_(module.weight)
            # classifier may be built with bias=False, in which case module.bias is None
            if module.bias is not None:
                nn.init.constant_(module.bias, head_bias)
        else:
            nn.init.trunc_normal_(module.weight, std=0.02)
            # only Linear biases are set here, conv biases are left as they are
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.constant_(module.bias, init_bias)
    elif hasattr(module, 'init_weights'):
        # defer to the module's own init routine when it provides one
        module.init_weights()
|
||||||
|
|
||||||
|
|
||||||
|
def init_weight_jax(
        module: nn.Module,
        name: str,
        head_bias: float = 0.,
        classifier_name: str = 'head',
):
    """ JAX-like weight init, intended for use as a `(module, name)` callback.

    Args:
        module: Module to initialize.
        name: Qualified name of the module, compared against `classifier_name` and checked for 'mlp'.
        head_bias: Constant used for the bias of the classifier layer.
        classifier_name: Name prefix that identifies the classifier layer.
    """
    if isinstance(module, nn.Linear):
        if name.startswith(classifier_name):
            nn.init.zeros_(module.weight)
            # classifier may be built with bias=False, in which case module.bias is None
            if module.bias is not None:
                nn.init.constant_(module.bias, head_bias)
        else:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                # layers with 'mlp' in their name get a tiny normal bias, all others zero
                if 'mlp' in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        # defer to the module's own init routine when it provides one
        module.init_weights()
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ from .ghostnet import *
|
|||||||
from .hardcorenas import *
|
from .hardcorenas import *
|
||||||
from .hgnet import *
|
from .hgnet import *
|
||||||
from .hiera import *
|
from .hiera import *
|
||||||
|
from .hieradet_sam2 import *
|
||||||
from .hrnet import *
|
from .hrnet import *
|
||||||
from .inception_next import *
|
from .inception_next import *
|
||||||
from .inception_resnet_v2 import *
|
from .inception_resnet_v2 import *
|
||||||
|
@ -1290,8 +1290,12 @@ default_cfgs = generate_default_cfgs({
|
|||||||
'efficientnet_b0.ra4_e3600_r224_in1k': _cfg(
|
'efficientnet_b0.ra4_e3600_r224_in1k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
|
||||||
crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0
|
crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0),
|
||||||
),
|
'efficientnet_b1.ra4_e3600_r240_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
|
||||||
|
input_size=(3, 240, 240), crop_pct=0.9, pool_size=(8, 8),
|
||||||
|
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||||
'efficientnet_b1.ft_in1k': _cfg(
|
'efficientnet_b1.ft_in1k': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
|
@ -31,15 +31,15 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple
|
from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
|
||||||
|
_assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax
|
||||||
|
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._features import feature_take_indices
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_function
|
from ._features_fx import register_notrace_function
|
||||||
|
from ._manipulate import named_apply
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['Hiera']
|
__all__ = ['Hiera']
|
||||||
@ -288,7 +288,6 @@ class MaskUnitAttention(nn.Module):
|
|||||||
""" Input should be of shape [batch, tokens, channels]. """
|
""" Input should be of shape [batch, tokens, channels]. """
|
||||||
B, N, _ = x.shape
|
B, N, _ = x.shape
|
||||||
num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
|
num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
|
||||||
|
|
||||||
qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
|
qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
|
||||||
q, k, v = qkv.unbind(0)
|
q, k, v = qkv.unbind(0)
|
||||||
|
|
||||||
@ -317,6 +316,7 @@ class HieraBlock(nn.Module):
|
|||||||
heads: int,
|
heads: int,
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
drop_path: float = 0.0,
|
drop_path: float = 0.0,
|
||||||
|
init_values: Optional[float] = None,
|
||||||
norm_layer: nn.Module = nn.LayerNorm,
|
norm_layer: nn.Module = nn.LayerNorm,
|
||||||
act_layer: nn.Module = nn.GELU,
|
act_layer: nn.Module = nn.GELU,
|
||||||
q_stride: int = 1,
|
q_stride: int = 1,
|
||||||
@ -325,7 +325,6 @@ class HieraBlock(nn.Module):
|
|||||||
use_mask_unit_attn: bool = False,
|
use_mask_unit_attn: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.dim_out = dim_out
|
self.dim_out = dim_out
|
||||||
|
|
||||||
@ -348,13 +347,14 @@ class HieraBlock(nn.Module):
|
|||||||
window_size,
|
window_size,
|
||||||
use_mask_unit_attn
|
use_mask_unit_attn
|
||||||
)
|
)
|
||||||
|
self.ls1 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
|
||||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
|
||||||
|
|
||||||
self.norm2 = norm_layer(dim_out)
|
self.norm2 = norm_layer(dim_out)
|
||||||
self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
|
self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
|
||||||
|
self.ls2 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity()
|
||||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# Attention + Q Pooling
|
# Attention + Q Pooling
|
||||||
x_norm = self.norm1(x)
|
x_norm = self.norm1(x)
|
||||||
@ -369,48 +369,10 @@ class HieraBlock(nn.Module):
|
|||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
x = x + self.drop_path1(self.attn(x_norm))
|
x = x + self.drop_path1(self.ls1(self.attn(x_norm)))
|
||||||
|
|
||||||
# MLP
|
# MLP
|
||||||
x = x + self.drop_path2(self.mlp(self.norm2(x)))
|
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class NormClassifierHead(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_features: int,
|
|
||||||
num_classes: int,
|
|
||||||
pool_type: str = 'avg',
|
|
||||||
drop_rate: float = 0.0,
|
|
||||||
norm_layer: Union[str, Callable] = 'layernorm',
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
norm_layer = get_norm_layer(norm_layer)
|
|
||||||
assert pool_type in ('avg', '')
|
|
||||||
self.in_features = self.num_features = in_features
|
|
||||||
self.pool_type = pool_type
|
|
||||||
self.norm = norm_layer(in_features)
|
|
||||||
self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
|
|
||||||
self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
||||||
|
|
||||||
def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
|
|
||||||
if pool_type is not None:
|
|
||||||
assert pool_type in ('avg', '')
|
|
||||||
self.pool_type = pool_type
|
|
||||||
if other:
|
|
||||||
# reset other non-fc layers
|
|
||||||
self.norm = nn.Identity()
|
|
||||||
self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
|
||||||
if self.pool_type == 'avg':
|
|
||||||
x = x.mean(dim=1)
|
|
||||||
x = self.norm(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
if pre_logits:
|
|
||||||
return x
|
|
||||||
x = self.fc(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -470,6 +432,7 @@ class Hiera(nn.Module):
|
|||||||
mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1)
|
mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1)
|
||||||
# mask_unit_attn: which stages use mask unit attention?
|
# mask_unit_attn: which stages use mask unit attention?
|
||||||
mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
|
mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
|
||||||
|
use_expand_proj: bool = True,
|
||||||
dim_mul: float = 2.0,
|
dim_mul: float = 2.0,
|
||||||
head_mul: float = 2.0,
|
head_mul: float = 2.0,
|
||||||
patch_kernel: Tuple[int, ...] = (7, 7),
|
patch_kernel: Tuple[int, ...] = (7, 7),
|
||||||
@ -477,13 +440,16 @@ class Hiera(nn.Module):
|
|||||||
patch_padding: Tuple[int, ...] = (3, 3),
|
patch_padding: Tuple[int, ...] = (3, 3),
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
drop_path_rate: float = 0.0,
|
drop_path_rate: float = 0.0,
|
||||||
|
init_values: Optional[float] = None,
|
||||||
|
fix_init: bool = True,
|
||||||
|
weight_init: str = '',
|
||||||
norm_layer: Union[str, nn.Module] = "LayerNorm",
|
norm_layer: Union[str, nn.Module] = "LayerNorm",
|
||||||
drop_rate: float = 0.0,
|
drop_rate: float = 0.0,
|
||||||
patch_drop_rate: float = 0.0,
|
patch_drop_rate: float = 0.0,
|
||||||
head_init_scale: float = 0.001,
|
head_init_scale: float = 0.001,
|
||||||
sep_pos_embed: bool = False,
|
sep_pos_embed: bool = False,
|
||||||
abs_win_pos_embed: bool = False,
|
abs_win_pos_embed: bool = False,
|
||||||
abs_pos_size: Tuple[int, int] = (14, 14),
|
global_pos_size: Tuple[int, int] = (14, 14),
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
@ -510,11 +476,9 @@ class Hiera(nn.Module):
|
|||||||
patch_kernel,
|
patch_kernel,
|
||||||
patch_stride,
|
patch_stride,
|
||||||
patch_padding,
|
patch_padding,
|
||||||
#reshape=False, # leave spatial / temporal dims in output
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pos_embed: Optional[nn.Parameter] = None
|
self.pos_embed: Optional[nn.Parameter] = None
|
||||||
self.pos_embed_abs: Optional[nn.Parameter] = None
|
|
||||||
self.pos_embed_win: Optional[nn.Parameter] = None
|
self.pos_embed_win: Optional[nn.Parameter] = None
|
||||||
self.pos_embed_spatial: Optional[nn.Parameter] = None
|
self.pos_embed_spatial: Optional[nn.Parameter] = None
|
||||||
self.pos_embed_temporal: Optional[nn.Parameter] = None
|
self.pos_embed_temporal: Optional[nn.Parameter] = None
|
||||||
@ -528,7 +492,7 @@ class Hiera(nn.Module):
|
|||||||
else:
|
else:
|
||||||
if abs_win_pos_embed:
|
if abs_win_pos_embed:
|
||||||
# absolute win, params NCHW to make tile & interpolate more natural before add & reshape
|
# absolute win, params NCHW to make tile & interpolate more natural before add & reshape
|
||||||
self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))
|
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size))
|
||||||
self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
|
self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
|
||||||
else:
|
else:
|
||||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
|
||||||
@ -552,7 +516,7 @@ class Hiera(nn.Module):
|
|||||||
# Transformer blocks
|
# Transformer blocks
|
||||||
cur_stage = 0
|
cur_stage = 0
|
||||||
depth = sum(stages)
|
depth = sum(stages)
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||||
self.blocks = nn.ModuleList()
|
self.blocks = nn.ModuleList()
|
||||||
self.feature_info = []
|
self.feature_info = []
|
||||||
for i in range(depth):
|
for i in range(depth):
|
||||||
@ -575,9 +539,11 @@ class Hiera(nn.Module):
|
|||||||
heads=num_heads,
|
heads=num_heads,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
drop_path=dpr[i],
|
drop_path=dpr[i],
|
||||||
|
init_values=init_values,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
q_stride=(flat_q_stride if i in q_pool_blocks else 1),
|
q_stride=(flat_q_stride if i in q_pool_blocks else 1),
|
||||||
window_size=flat_mu_size,
|
window_size=flat_mu_size,
|
||||||
|
use_expand_proj=use_expand_proj,
|
||||||
use_mask_unit_attn=use_mask_unit_attn,
|
use_mask_unit_attn=use_mask_unit_attn,
|
||||||
)
|
)
|
||||||
embed_dim = dim_out
|
embed_dim = dim_out
|
||||||
@ -587,12 +553,13 @@ class Hiera(nn.Module):
|
|||||||
self.blocks.append(block)
|
self.blocks.append(block)
|
||||||
|
|
||||||
self.num_features = self.head_hidden_size = embed_dim
|
self.num_features = self.head_hidden_size = embed_dim
|
||||||
self.head = NormClassifierHead(
|
self.head = ClNormMlpClassifierHead(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
num_classes,
|
num_classes,
|
||||||
pool_type=global_pool,
|
pool_type=global_pool,
|
||||||
drop_rate=drop_rate,
|
drop_rate=drop_rate,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
input_fmt='NLC',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize everything
|
# Initialize everything
|
||||||
@ -602,22 +569,26 @@ class Hiera(nn.Module):
|
|||||||
else:
|
else:
|
||||||
if self.pos_embed is not None:
|
if self.pos_embed is not None:
|
||||||
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
||||||
elif self.pos_embed_abs is not None:
|
if self.pos_embed_win is not None:
|
||||||
nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
|
|
||||||
nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
|
nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
|
||||||
self.apply(partial(self._init_weights))
|
|
||||||
|
if weight_init != 'skip':
|
||||||
|
init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
|
||||||
|
init_fn = partial(init_fn, classifier_name='head.fc')
|
||||||
|
named_apply(init_fn, self)
|
||||||
|
if fix_init:
|
||||||
|
self.fix_init_weight()
|
||||||
if isinstance(self.head.fc, nn.Linear):
|
if isinstance(self.head.fc, nn.Linear):
|
||||||
self.head.fc.weight.data.mul_(head_init_scale)
|
self.head.fc.weight.data.mul_(head_init_scale)
|
||||||
self.head.fc.bias.data.mul_(head_init_scale)
|
self.head.fc.bias.data.mul_(head_init_scale)
|
||||||
|
|
||||||
def _init_weights(self, m, init_bias=0.02):
|
def fix_init_weight(self):
|
||||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
def rescale(param, _layer_id):
|
||||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
param.div_(math.sqrt(2.0 * _layer_id))
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, init_bias)
|
for layer_id, layer in enumerate(self.blocks):
|
||||||
elif isinstance(m, nn.LayerNorm):
|
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||||
nn.init.constant_(m.bias, init_bias)
|
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def no_weight_decay(self):
|
def no_weight_decay(self):
|
||||||
@ -643,9 +614,9 @@ class Hiera(nn.Module):
|
|||||||
def get_classifier(self):
|
def get_classifier(self):
|
||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool, other=other)
|
self.head.reset(num_classes, global_pool, reset_other=reset_other)
|
||||||
|
|
||||||
def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
|
def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -672,20 +643,20 @@ class Hiera(nn.Module):
|
|||||||
return mask.bool()
|
return mask.bool()
|
||||||
|
|
||||||
def _pos_embed(self, x) -> torch.Tensor:
|
def _pos_embed(self, x) -> torch.Tensor:
|
||||||
if self.pos_embed is not None:
|
if self.pos_embed_win is not None:
|
||||||
pos_embed = self.pos_embed
|
|
||||||
elif self.pos_embed_abs is not None:
|
|
||||||
# absolute win position embedding, from
|
# absolute win position embedding, from
|
||||||
# Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
|
# Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
|
||||||
pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
|
pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
|
||||||
pos_embed_abs = F.interpolate(
|
pos_embed = F.interpolate(
|
||||||
self.pos_embed_abs,
|
self.pos_embed,
|
||||||
size=pos_embed_win.shape[-2:],
|
size=pos_embed_win.shape[-2:],
|
||||||
mode='bicubic',
|
mode='bicubic',
|
||||||
antialias=True,
|
antialias=True,
|
||||||
)
|
)
|
||||||
pos_embed = pos_embed_abs + pos_embed_win
|
pos_embed = pos_embed + pos_embed_win
|
||||||
pos_embed = pos_embed.flatten(2).transpose(1, 2)
|
pos_embed = pos_embed.flatten(2).transpose(1, 2)
|
||||||
|
elif self.pos_embed is not None:
|
||||||
|
pos_embed = self.pos_embed
|
||||||
else:
|
else:
|
||||||
pos_embed = (
|
pos_embed = (
|
||||||
self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
|
self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
|
||||||
@ -708,6 +679,7 @@ class Hiera(nn.Module):
|
|||||||
stop_early: bool = True,
|
stop_early: bool = True,
|
||||||
output_fmt: str = 'NCHW',
|
output_fmt: str = 'NCHW',
|
||||||
intermediates_only: bool = False,
|
intermediates_only: bool = False,
|
||||||
|
coarse: bool = True,
|
||||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
""" Forward features that returns intermediates.
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
@ -722,10 +694,13 @@ class Hiera(nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
assert not norm, 'normalization of features not supported'
|
assert not norm, 'normalization of features not supported'
|
||||||
assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
|
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
|
||||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
if coarse:
|
||||||
take_indices = [self.stage_ends[i] for i in take_indices]
|
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||||
max_index = self.stage_ends[max_index]
|
take_indices = [self.stage_ends[i] for i in take_indices]
|
||||||
|
max_index = self.stage_ends[max_index]
|
||||||
|
else:
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
|
patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
|
||||||
@ -747,7 +722,8 @@ class Hiera(nn.Module):
|
|||||||
for i, blk in enumerate(blocks):
|
for i, blk in enumerate(blocks):
|
||||||
x = blk(x)
|
x = blk(x)
|
||||||
if i in take_indices:
|
if i in take_indices:
|
||||||
intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
|
x_int = self.reroll(x, i, mask=mask)
|
||||||
|
intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
|
||||||
|
|
||||||
if intermediates_only:
|
if intermediates_only:
|
||||||
return intermediates
|
return intermediates
|
||||||
@ -759,14 +735,18 @@ class Hiera(nn.Module):
|
|||||||
indices: Union[int, List[int]] = 1,
|
indices: Union[int, List[int]] = 1,
|
||||||
prune_norm: bool = False,
|
prune_norm: bool = False,
|
||||||
prune_head: bool = True,
|
prune_head: bool = True,
|
||||||
|
coarse: bool = True,
|
||||||
):
|
):
|
||||||
""" Prune layers not required for specified intermediates.
|
""" Prune layers not required for specified intermediates.
|
||||||
"""
|
"""
|
||||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
if coarse:
|
||||||
max_index = self.stage_ends[max_index]
|
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||||
|
max_index = self.stage_ends[max_index]
|
||||||
|
else:
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||||
if prune_head:
|
if prune_head:
|
||||||
self.head.reset(0, other=True)
|
self.head.reset(0, reset_other=True)
|
||||||
return take_indices
|
return take_indices
|
||||||
|
|
||||||
def forward_features(
|
def forward_features(
|
||||||
@ -901,8 +881,22 @@ default_cfgs = generate_default_cfgs({
|
|||||||
num_classes=0,
|
num_classes=0,
|
||||||
),
|
),
|
||||||
|
|
||||||
"hiera_small_abswin_256.untrained": _cfg(
|
"hiera_small_abswin_256.sbb2_e200_in12k_ft_in1k": _cfg(
|
||||||
#hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
|
input_size=(3, 256, 256), crop_pct=0.95,
|
||||||
|
),
|
||||||
|
"hiera_small_abswin_256.sbb2_pd_e200_in12k_ft_in1k": _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
input_size=(3, 256, 256), crop_pct=0.95,
|
||||||
|
),
|
||||||
|
"hiera_small_abswin_256.sbb2_e200_in12k": _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
num_classes=11821,
|
||||||
|
input_size=(3, 256, 256), crop_pct=0.95,
|
||||||
|
),
|
||||||
|
"hiera_small_abswin_256.sbb2_pd_e200_in12k": _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
num_classes=11821,
|
||||||
input_size=(3, 256, 256), crop_pct=0.95,
|
input_size=(3, 256, 256), crop_pct=0.95,
|
||||||
),
|
),
|
||||||
"hiera_base_abswin_256.untrained": _cfg(
|
"hiera_base_abswin_256.untrained": _cfg(
|
||||||
@ -931,6 +925,8 @@ def checkpoint_filter_fn(state_dict, model=None):
|
|||||||
k = k.replace('encoder_norm.', 'head.norm.')
|
k = k.replace('encoder_norm.', 'head.norm.')
|
||||||
elif k.startswith('norm.'):
|
elif k.startswith('norm.'):
|
||||||
k = k.replace('norm.', 'head.norm.')
|
k = k.replace('norm.', 'head.norm.')
|
||||||
|
if k == 'pos_embed_abs':
|
||||||
|
k = 'pos_embed'
|
||||||
output[k] = v
|
output[k] = v
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -947,6 +943,7 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def hiera_tiny_224(pretrained=False, **kwargs):
|
def hiera_tiny_224(pretrained=False, **kwargs):
|
||||||
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
|
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
|
||||||
@ -985,11 +982,15 @@ def hiera_huge_224(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def hiera_small_abswin_256(pretrained=False, **kwargs):
|
def hiera_small_abswin_256(pretrained=False, **kwargs):
|
||||||
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16))
|
model_args = dict(
|
||||||
|
embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, global_pos_size=(16, 16),
|
||||||
|
init_values=1e-5, weight_init='jax', use_expand_proj=False,
|
||||||
|
)
|
||||||
return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def hiera_base_abswin_256(pretrained=False, **kwargs):
|
def hiera_base_abswin_256(pretrained=False, **kwargs):
|
||||||
model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16))
|
model_args = dict(
|
||||||
return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, init_values=1e-5, weight_init='jax')
|
||||||
|
return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
635
timm/models/hieradet_sam2.py
Normal file
635
timm/models/hieradet_sam2.py
Normal file
@ -0,0 +1,635 @@
|
|||||||
|
import math
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.jit import Final
|
||||||
|
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
|
||||||
|
get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
|
||||||
|
|
||||||
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
|
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||||
|
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size: Tuple[int, int]):
    """
    Split a channels-last feature map into non-overlapping windows.

    No padding is applied here: H and W must already be exact multiples of the
    window height / width, otherwise the reshape below fails.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (Tuple[int, int]): window (height, width).
    Returns:
        windows: windows after partition with [B * num_windows, window_size[0], window_size[1], C].
    """
    B, H, W, C = x.shape
    win_h, win_w = window_size[0], window_size[1]
    # expose the window grid: (B, rows of windows, win_h, cols of windows, win_w, C)
    grid = x.view(B, H // win_h, win_h, W // win_w, win_w, C)
    # move both grid axes in front of the in-window axes, then fold them into the batch dim
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, win_h, win_w, C)
|
||||||
|
|
||||||
|
|
||||||
|
def window_unpartition(windows: torch.Tensor, window_size: Tuple[int, int], hw: Tuple[int, int]):
    """
    Reassemble windows produced by `window_partition` into a [B, H, W, C] map.

    This function does not crop anything; if the map was padded before
    partitioning, pass the padded size as `hw` and slice the result afterwards.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size[0], window_size[1], C].
        window_size (Tuple[int, int]): window (height, width).
        hw (Tuple): height and width (H, W) of the map that was partitioned.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    H, W = hw
    win_h, win_w = window_size[0], window_size[1]
    windows_per_image = H * W // win_h // win_w
    B = windows.shape[0] // windows_per_image
    # (B, rows of windows, cols of windows, win_h, win_w, C)
    grid = windows.view(B, H // win_h, W // win_w, win_h, win_w, -1)
    # interleave grid rows with in-window rows (and likewise for columns) before flattening
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_pad(H: int, W: int, window_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
|
||||||
|
pad_h = (window_size[0] - H % window_size[0]) % window_size[0]
|
||||||
|
pad_w = (window_size[1] - W % window_size[1]) % window_size[1]
|
||||||
|
Hp, Wp = H + pad_h, W + pad_w
|
||||||
|
return Hp, Wp, pad_h, pad_w
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleAttention(nn.Module):
    """Multi-head self-attention over a channels-last map, with optional query pooling.

    The input is projected from `dim` to `dim_out` channels for q, k and v. When
    `q_pool` is given it is applied to the query (in NCHW layout), so the output
    has the pooled spatial size.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: int,
            num_heads: int,
            q_pool: nn.Module = None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        # softmax scaling, only used on the non-fused path
        self.scale = (dim_out // num_heads) ** -0.5
        self.fused_attn = use_fused_attn()

        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape

        # joint projection -> (B, H * W, 3, nheads, head_dim), then split on the "3" axis
        packed = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        q, k, v = packed.unbind(2)

        if self.q_pool is not None:
            # pool only the query (downsample at stage changes); pooling wants BCHW
            q_map = q.reshape(B, H, W, -1).permute(0, 3, 1, 2)
            q_map = self.q_pool(q_map).permute(0, 2, 3, 1)
            H, W = q_map.shape[1:3]  # output now has the downsampled shape
            q = q_map.reshape(B, H * W, self.num_heads, -1)

        # attention expects (B, nheads, tokens, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            scores = (q * self.scale) @ k.transpose(-1, -2)
            x = scores.softmax(dim=-1) @ v

        # back to channels-last (B, H, W, dim_out) and apply the output projection
        x = x.transpose(1, 2).reshape(B, H, W, -1)
        return self.proj(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleBlock(nn.Module):
    """Pre-norm transformer block with optional windowed attention and query pooling.

    Args:
        dim: input channels.
        dim_out: output channels; when different from `dim` the shortcut is linearly projected.
        num_heads: number of attention heads.
        mlp_ratio: hidden size of the MLP as a multiple of `dim_out`.
        q_stride: if set, max-pool stride applied to the query (and the shortcut) to downsample.
        norm_layer: norm layer or its name, resolved via `get_norm_layer`.
        act_layer: activation layer or its name, resolved via `get_act_layer`.
        window_size: attention window size; 0 means attention over the whole map.
        init_values: if not None, initial value for LayerScale on both residual branches.
        drop_path: stochastic depth rate for both residual branches.
    """
    def __init__(
            self,
            dim: int,
            dim_out: int,
            num_heads: int,
            mlp_ratio: float = 4.0,
            q_stride: Optional[Tuple[int, int]] = None,
            norm_layer: Union[nn.Module, str] = "LayerNorm",
            act_layer: Union[nn.Module, str] = "GELU",
            window_size: int = 0,
            init_values: Optional[float] = None,
            drop_path: float = 0.0,
    ):
        super().__init__()
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)
        self.window_size = to_2tuple(window_size)
        # any non-zero window dimension enables window partitioning in forward()
        self.is_windowed = any(self.window_size)
        self.dim = dim
        self.dim_out = dim_out
        self.q_stride = q_stride

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)
        else:
            self.proj = nn.Identity()
        self.pool = None
        if self.q_stride:
            # note make a different instance for this Module so that it's not shared with attn module
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride,
                stride=q_stride,
                ceil_mode=False,
            )

        self.norm1 = norm_layer(dim)
        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=deepcopy(self.pool),
        )
        self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = Mlp(
            dim_out,
            int(dim_out * mlp_ratio),
            act_layer=act_layer,
        )
        self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        # NOTE(review): the shortcut is projected from the *normalized* input, and it is pooled
        # only inside this branch, i.e. pooling is assumed to coincide with a channel change
        # -- confirm for configs where q_stride is set but dim == dim_out.
        if self.dim != self.dim_out:
            shortcut = self.proj(x)
            if self.pool is not None:
                shortcut = shortcut.permute(0, 3, 1, 2)
                shortcut = self.pool(shortcut).permute(0, 2, 3, 1)

        # Window partition
        window_size = self.window_size
        H, W = x.shape[1:3]
        Hp, Wp = H, W  # keep torchscript happy
        if self.is_windowed:
            Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size)
            # pad bottom/right of the (H, W) axes of a B, H, W, C tensor
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
            x = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride is not None:
            # Shapes have changed due to Q pooling
            # NOTE(review): for a non-windowed block self.window_size is (0, 0), so _calc_pad
            # is called with a zero window here -- confirm _calc_pad tolerates that.
            window_size = (self.window_size[0] // self.q_stride[0], self.window_size[1] // self.q_stride[1])
            H, W = shortcut.shape[1:3]
            Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size)

        # Reverse window partition
        if self.is_windowed:
            x = window_unpartition(x, window_size, (Hp, Wp))
            x = x[:, :H, :W, :].contiguous()  # unpad

        # residual connections: attention branch onto the (possibly projected/pooled) shortcut, then MLP
        x = shortcut + self.drop_path1(self.ls1(x))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class HieraPatchEmbed(nn.Module):
    """
    Image to Patch Embedding via a single strided convolution, returning channels-last output.
    """

    def __init__(
            self,
            kernel_size: Tuple[int, ...] = (7, 7),
            stride: Tuple[int, ...] = (4, 4),
            padding: Tuple[int, ...] = (3, 3),
            in_chans: int = 3,
            embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # conv output is B C H W; downstream blocks expect B H W C
        return self.proj(x).permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class HieraDet(nn.Module):
    """ Hierarchical ViT trunk built from `MultiScaleBlock`s with windowed + global attention.

    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 96,  # initial embed dim
            num_heads: int = 1,  # initial number of heads
            patch_kernel: Tuple[int, ...] = (7, 7),
            patch_stride: Tuple[int, ...] = (4, 4),
            patch_padding: Tuple[int, ...] = (3, 3),
            patch_size: Optional[Tuple[int, ...]] = None,
            q_pool: int = 3,  # number of q_pool stages
            q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
            stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
            dim_mul: float = 2.0,  # dim_mul factor at stage shift
            head_mul: float = 2.0,  # head_mul factor at stage shift
            global_pos_size: Tuple[int, int] = (7, 7),
            # window size per stage, when not using global att.
            window_spec: Tuple[int, ...] = (
                8,
                4,
                14,
                7,
            ),
            # global attn in these blocks
            global_att_blocks: Tuple[int, ...] = (
                12,
                16,
                20,
            ),
            init_values: Optional[float] = None,
            weight_init: str = '',
            fix_init: bool = True,
            head_init_scale: float = 0.001,
            drop_rate: float = 0.0,
            drop_path_rate: float = 0.0,  # stochastic depth
            norm_layer: Union[nn.Module, str] = "LayerNorm",
            act_layer: Union[nn.Module, str] = "GELU",
    ):
        """
        Args:
            in_chans: number of input image channels.
            num_classes: classifier outputs (passed to the head).
            global_pool: pooling type used by the classifier head.
            embed_dim: channel width of the first stage.
            num_heads: attention heads in the first stage.
            patch_kernel / patch_stride / patch_padding: conv stem settings, used when `patch_size` is None.
            patch_size: if set, use a non-overlapping ViT-style `PatchEmbed` instead of the conv stem.
            q_pool: number of stage transitions that apply query pooling.
            q_stride: pooling stride used at those transitions.
            stages: number of blocks per stage.
            dim_mul / head_mul: width / head multipliers applied at each stage transition.
            global_pos_size: spatial size of the learned global position embedding.
            window_spec: attention window size per stage.
            global_att_blocks: indices of blocks that use global (non-windowed) attention.
            init_values: LayerScale init value forwarded to every block (None disables LayerScale).
            weight_init: 'jax', 'skip', or anything else for the ViT-style init.
            fix_init: rescale attn-proj / mlp-fc2 weights by depth after init.
            head_init_scale: multiplier applied to the classifier weights and bias after init.
            drop_rate: dropout rate passed to the head.
            drop_path_rate: maximum stochastic depth rate (linearly increased over blocks).
            norm_layer / act_layer: layer types or names.
        """
        super().__init__()
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)
        assert len(stages) == len(window_spec)
        self.num_classes = num_classes
        self.window_spec = window_spec
        self.output_fmt = 'NHWC'
        # flag toggled by set_grad_checkpointing(); must exist before forward_features reads it
        self.grad_checkpointing = False

        depth = sum(stages)
        self.q_stride = q_stride
        # index of the last block of each stage
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # first block of each following stage pools the query, limited to the first `q_pool` transitions
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]

        if patch_size is not None:
            # use a non-overlapping vit style patch embed
            self.patch_embed = PatchEmbed(
                img_size=None,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                output_fmt='NHWC',
                dynamic_img_pad=True,
            )
        else:
            self.patch_embed = HieraPatchEmbed(
                kernel_size=patch_kernel,
                stride=patch_stride,
                padding=patch_padding,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.global_pos_size = global_pos_size
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size))
        self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        cur_stage = 0
        self.blocks = nn.Sequential()
        self.feature_info = []
        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
                norm_layer=norm_layer,
                act_layer=act_layer,
                # forward the LayerScale setting; previously accepted by __init__ but silently dropped
                init_values=init_values,
            )

            embed_dim = dim_out
            self.blocks.append(block)
            if i in self.stage_ends:
                # NOTE(review): reduction assumes a stride-4 stem and a 2x downsample at every
                # stage transition -- confirm for non-default patch_stride / q_stride / q_pool.
                self.feature_info += [
                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]

        self.num_features = self.head_hidden_size = embed_dim
        self.head = ClNormMlpClassifierHead(
            embed_dim,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            norm_layer=norm_layer,
        )

        # Initialize everything
        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        if self.pos_embed_window is not None:
            nn.init.trunc_normal_(self.pos_embed_window, std=0.02)

        if weight_init != 'skip':
            init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
            init_fn = partial(init_fn, classifier_name='head.fc')
            named_apply(init_fn, self)

        if fix_init:
            self.fix_init_weight()

        if isinstance(self.head, ClNormMlpClassifierHead) and isinstance(self.head.fc, nn.Linear):
            self.head.fc.weight.data.mul_(head_init_scale)
            self.head.fc.bias.data.mul_(head_init_scale)

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Add the global (interpolated) + window (tiled) position embedding to a B, H, W, C map."""
        h, w = x.shape[1:3]
        window_embed = self.pos_embed_window
        # resize the global embedding to the current token grid
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        # NOTE(review): integer division means h / w are expected to be multiples of the
        # window embedding size, otherwise the tiled tensor will not match -- confirm.
        tile_h = pos_embed.shape[-2] // window_embed.shape[-2]
        tile_w = pos_embed.shape[-1] // window_embed.shape[-1]
        pos_embed = pos_embed + window_embed.tile((tile_h, tile_w))
        pos_embed = pos_embed.permute(0, 2, 3, 1)  # B C H W -> B H W C
        return x + pos_embed

    def fix_init_weight(self):
        """Divide each block's attn.proj and mlp.fc2 weights by sqrt(2 * (block index + 1))."""
        def rescale(param, _layer_id):
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters excluded from weight decay."""
        return ['pos_embed', 'pos_embed_window']

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        """Regex groups splitting parameters into stem and per-block groups."""
        return dict(
            stem=r'^pos_embed|pos_embed_window|patch_embed',
            blocks=[(r'^blocks\.(\d+)', None)]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Enable / disable gradient checkpointing of the block stack in forward_features."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the final classifier layer of the head."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
        """Replace the classifier with one for `num_classes` outputs (delegates to the head)."""
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = True,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
            coarse: bool = True,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
            coarse: Take coarse features (stage ends) if true, otherwise all block features
        Returns:
            List of intermediate features if `intermediates_only`, otherwise a tuple of
            (output of the last executed block in NHWC, list of intermediate features).
        """
        assert not norm, 'normalization of features not supported'
        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
        if coarse:
            # indices address stages; map them to the block index that ends each stage
            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
            take_indices = [self.stage_ends[i] for i in take_indices]
            max_index = self.stage_ends[max_index]
        else:
            take_indices, max_index = feature_take_indices(len(self.blocks), indices)

        x = self.patch_embed(x)
        x = self._pos_embed(x)

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:
            blocks = self.blocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            x = blk(x)
            if i in take_indices:
                # blocks work in NHWC; convert only the returned copy
                x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
                intermediates.append(x_out)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
            coarse: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        if coarse:
            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
            max_index = self.stage_ends[max_index]
        else:
            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_head:
            self.head.reset(0, reset_other=prune_norm)
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem, position embedding and all blocks; returns a B, H, W, C feature map."""
        x = self.patch_embed(x)  # BHWC
        x = self._pos_embed(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # honour set_grad_checkpointing(); the flag used to be stored but never read
            x = checkpoint_seq(self.blocks, x)
        else:
            for blk in self.blocks:
                x = blk(x)
        return x

    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
        """Apply the classifier head; with `pre_logits` the head is asked for pre-classifier features."""
        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE sam2 appears to use 1024x1024 for all models, but T, S, & B+ have windows that fit multiples of 224.
|
||||||
|
def _cfg(url='', **kwargs):
    """Build a pretrained-config dict from the defaults below, overridden by `kwargs`."""
    cfg = {
        'url': url,
        'num_classes': 0, 'input_size': (3, 896, 896), 'pool_size': (28, 28),
        'crop_pct': 1.0, 'interpolation': 'bicubic', 'min_input_size': (3, 224, 224),
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
    }
    # caller-supplied entries win over the defaults
    cfg.update(kwargs)
    return cfg
|
||||||
|
|
||||||
|
|
||||||
|
# Pretrained configs keyed by '<model name>[.<tag>]'. Entries that omit a field
# (e.g. input_size / pool_size / num_classes) take the defaults defined in `_cfg`.
default_cfgs = generate_default_cfgs({
    # sam2_* entries point at the SAM2 checkpoints on the HF hub (repo id + filename)
    "sam2_hiera_tiny.r224": _cfg(
        hf_hub_id='facebook/sam2-hiera-tiny',
        hf_hub_filename='sam2_hiera_tiny.pt',
        input_size=(3, 224, 224), pool_size=(7, 7),
    ),  # FIXME reduced res for testing
    # same checkpoint file as .r224, at the default 896x896 input size from `_cfg`
    "sam2_hiera_tiny.r896": _cfg(
        hf_hub_id='facebook/sam2-hiera-tiny',
        hf_hub_filename='sam2_hiera_tiny.pt',
    ),
    "sam2_hiera_small": _cfg(
        hf_hub_id='facebook/sam2-hiera-small',
        hf_hub_filename='sam2_hiera_small.pt',
    ),
    "sam2_hiera_base_plus": _cfg(
        hf_hub_id='facebook/sam2-hiera-base-plus',
        hf_hub_filename='sam2_hiera_base_plus.pt',
    ),
    # large variant overrides the input / minimum sizes (1024 and 256 instead of 896 and 224)
    "sam2_hiera_large": _cfg(
        hf_hub_id='facebook/sam2-hiera-large',
        hf_hub_filename='sam2_hiera_large.pt',
        min_input_size=(3, 256, 256),
        input_size=(3, 1024, 1024), pool_size=(32, 32),
    ),
    # no weights referenced (no url / hub id); 1000-class head at 256x256
    "hieradet_small.untrained": _cfg(
        num_classes=1000,
        input_size=(3, 256, 256), pool_size=(8, 8),
    ),
})
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_filter_fn(state_dict, model=None, prefix=''):
    """Remap a source checkpoint's keys to this module's parameter names.

    Args:
        state_dict: checkpoint mapping; if it has a 'model' entry, that nested mapping is used.
        model: unused, kept for the filter-fn call signature.
        prefix: only keys starting with this prefix are kept, and the prefix is stripped.

    Returns:
        New dict with the prefix removed and `mlp.layers.{0,1}` renamed to `mlp.fc{1,2}`.
    """
    state_dict = state_dict.get('model', state_dict)
    prefix_len = len(prefix)

    output = {}
    for k, v in state_dict.items():
        if not k.startswith(prefix):
            # key belongs to a different sub-module of the source checkpoint
            continue
        # strip only the leading prefix; str.replace would also remove later occurrences in the key
        k = k[prefix_len:]
        k = k.replace('mlp.layers.0', 'mlp.fc1')
        k = k.replace('mlp.layers.1', 'mlp.fc2')
        output[k] = v
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet:
    """Build a `HieraDet` variant through `build_model_with_cfg`, wiring up checkpoint filtering."""
    out_indices = kwargs.pop('out_indices', 4)
    is_sam2 = 'sam2' in variant
    if is_sam2:
        # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
        # This is workaround loading with num_classes=0 w/o removing norm-layer.
        kwargs.setdefault('pretrained_strict', False)
    # SAM2 checkpoints keep the trunk weights under this key prefix
    checkpoint_prefix = 'image_encoder.trunk.' if is_sam2 else ''
    filter_fn = partial(checkpoint_filter_fn, prefix=checkpoint_prefix)
    feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
    return build_model_with_cfg(
        HieraDet,
        variant,
        pretrained,
        pretrained_filter_fn=filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def sam2_hiera_tiny(pretrained=False, **kwargs):
    """SAM2 Hiera tiny trunk: stages (1, 2, 7, 2) with global attention in blocks 5, 7, 9."""
    model_args = {'stages': (1, 2, 7, 2), 'global_att_blocks': (5, 7, 9)}
    # caller kwargs override the variant defaults
    merged = {**model_args, **kwargs}
    return _create_hiera_det('sam2_hiera_tiny', pretrained=pretrained, **merged)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def sam2_hiera_small(pretrained=False, **kwargs):
    """SAM2 Hiera small trunk: stages (1, 2, 11, 2) with global attention in blocks 7, 10, 13."""
    model_args = {'stages': (1, 2, 11, 2), 'global_att_blocks': (7, 10, 13)}
    # caller kwargs override the variant defaults
    merged = {**model_args, **kwargs}
    return _create_hiera_det('sam2_hiera_small', pretrained=pretrained, **merged)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def sam2_hiera_base_plus(pretrained=False, **kwargs):
    """SAM2 Hiera base+ trunk: width 112, 2 initial heads, 14x14 global position embedding.

    Stage depths and global-attention blocks are left at the `HieraDet` defaults.
    """
    model_args = {'embed_dim': 112, 'num_heads': 2, 'global_pos_size': (14, 14)}
    # caller kwargs override the variant defaults
    merged = {**model_args, **kwargs}
    return _create_hiera_det('sam2_hiera_base_plus', pretrained=pretrained, **merged)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def sam2_hiera_large(pretrained=False, **kwargs):
    """SAM2 Hiera large trunk: width 144, stages (2, 6, 36, 4), windows (8, 4, 16, 8)."""
    model_args = {
        'embed_dim': 144,
        'num_heads': 2,
        'stages': (2, 6, 36, 4),
        'global_att_blocks': (23, 33, 43),
        'window_spec': (8, 4, 16, 8),
    }
    # caller kwargs override the variant defaults
    merged = {**model_args, **kwargs}
    return _create_hiera_det('sam2_hiera_large', pretrained=pretrained, **merged)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def hieradet_small(pretrained=False, **kwargs):
    """HieraDet small: stages (1, 2, 11, 2), windows (8, 4, 16, 8), init_values=1e-5."""
    model_args = {
        'stages': (1, 2, 11, 2),
        'global_att_blocks': (7, 10, 13),
        'window_spec': (8, 4, 16, 8),
        'init_values': 1e-5,
    }
    # caller kwargs override the variant defaults
    merged = {**model_args, **kwargs}
    return _create_hiera_det('hieradet_small', pretrained=pretrained, **merged)
|
||||||
|
|
||||||
|
|
||||||
|
# @register_model
|
||||||
|
# def hieradet_base(pretrained=False, **kwargs):
|
||||||
|
# model_args = dict(window_spec=(8, 4, 16, 8))
|
||||||
|
# return _create_hiera_det('hieradet_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
@ -783,6 +783,11 @@ default_cfgs = generate_default_cfgs({
|
|||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
|
||||||
first_conv='conv1.0'),
|
first_conv='conv1.0'),
|
||||||
|
'resnet50d.ra4_e3600_r224_in1k': _rcfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
|
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0,
|
||||||
|
first_conv='conv1.0'),
|
||||||
'resnet50d.a1_in1k': _rcfg(
|
'resnet50d.a1_in1k': _rcfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth',
|
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth',
|
||||||
|
@ -1964,26 +1964,46 @@ default_cfgs = {
|
|||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
num_classes=11821,
|
num_classes=11821,
|
||||||
input_size=(3, 256, 256), crop_pct=0.95),
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
|
'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
|
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
input_size=(3, 256, 256), crop_pct=0.95),
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
|
'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
num_classes=11821,
|
||||||
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
|
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
num_classes=11821,
|
num_classes=11821,
|
||||||
input_size=(3, 256, 256), crop_pct=0.95),
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
|
'vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
input_size=(3, 384, 384), crop_pct=1.0),
|
||||||
'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
|
'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
input_size=(3, 256, 256), crop_pct=0.95),
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
|
'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
|
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
input_size=(3, 256, 256), crop_pct=0.95),
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
|
'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
input_size=(3, 256, 256), crop_pct=0.95),
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
|
'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
num_classes=11821,
|
||||||
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
|
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
num_classes=11821,
|
num_classes=11821,
|
||||||
input_size=(3, 256, 256), crop_pct=0.95),
|
input_size=(3, 256, 256), crop_pct=0.95),
|
||||||
|
'vit_betwixt_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
input_size=(3, 384, 384), crop_pct=1.0),
|
||||||
'vit_base_patch16_reg4_gap_256.untrained': _cfg(
|
'vit_base_patch16_reg4_gap_256.untrained': _cfg(
|
||||||
input_size=(3, 256, 256)),
|
input_size=(3, 256, 256)),
|
||||||
|
|
||||||
@ -3110,6 +3130,17 @@ def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visi
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def vit_mediumd_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ``vit_mediumd_patch16_reg4_gap_384`` model.

    Defaults: ``patch_size=16``, ``embed_dim=512``, ``depth=20``,
    ``num_heads=8``, ``init_values=1e-5``, no class token
    (``class_token=False``, ``no_embed_class=True``), ``reg_tokens=4`` and
    ``global_pool='avg'``. Any key in ``kwargs`` replaces the default of the
    same name.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra / overriding model arguments.

    Returns:
        The model produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 512,
        'depth': 20,
        'num_heads': 8,
        'init_values': 1e-5,
        'class_token': False,
        'no_embed_class': True,
        'reg_tokens': 4,
        'global_pool': 'avg',
    }
    # Caller-supplied kwargs take precedence over the variant defaults.
    merged_args = {**defaults, **kwargs}
    return _create_vision_transformer(
        'vit_mediumd_patch16_reg4_gap_384', pretrained=pretrained, **merged_args)
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||||
model_args = dict(
|
model_args = dict(
|
||||||
@ -3132,6 +3163,17 @@ def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visi
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def vit_betwixt_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ``vit_betwixt_patch16_reg4_gap_384`` model.

    Defaults: ``patch_size=16``, ``embed_dim=640``, ``depth=12``,
    ``num_heads=10``, ``init_values=1e-5``, no class token
    (``class_token=False``, ``no_embed_class=True``), ``reg_tokens=4`` and
    ``global_pool='avg'``. Any key in ``kwargs`` replaces the default of the
    same name.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra / overriding model arguments.

    Returns:
        The model produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 640,
        'depth': 12,
        'num_heads': 10,
        'init_values': 1e-5,
        'class_token': False,
        'no_embed_class': True,
        'reg_tokens': 4,
        'global_pool': 'avg',
    }
    # Caller-supplied kwargs take precedence over the variant defaults.
    merged_args = {**defaults, **kwargs}
    return _create_vision_transformer(
        'vit_betwixt_patch16_reg4_gap_384', pretrained=pretrained, **merged_args)
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||||
model_args = dict(
|
model_args = dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user