mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Make gcvit window size ratio based to improve resolution changing support #1449. Change default init to original.
This commit is contained in:
parent
c45c6ee8e4
commit
f489f02ad1
@ -30,8 +30,8 @@ import torch.utils.checkpoint as checkpoint
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from .fx_features import register_notrace_function
|
from .fx_features import register_notrace_function
|
||||||
from .helpers import build_model_with_cfg, named_apply
|
from .helpers import build_model_with_cfg, named_apply
|
||||||
from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
|
from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\
|
||||||
ClassifierHead, LayerNorm2d, _assert
|
get_attn, get_act_layer, get_norm_layer, _assert
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location
|
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location
|
||||||
|
|
||||||
@ -321,7 +321,7 @@ class GlobalContextVitStage(nn.Module):
|
|||||||
depth: int,
|
depth: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
feat_size: Tuple[int, int],
|
feat_size: Tuple[int, int],
|
||||||
window_size: int,
|
window_size: Tuple[int, int],
|
||||||
downsample: bool = True,
|
downsample: bool = True,
|
||||||
global_norm: bool = False,
|
global_norm: bool = False,
|
||||||
stage_norm: bool = False,
|
stage_norm: bool = False,
|
||||||
@ -347,8 +347,9 @@ class GlobalContextVitStage(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.downsample = nn.Identity()
|
self.downsample = nn.Identity()
|
||||||
self.feat_size = feat_size
|
self.feat_size = feat_size
|
||||||
|
window_size = to_2tuple(window_size)
|
||||||
|
|
||||||
feat_levels = int(math.log2(min(feat_size) / window_size))
|
feat_levels = int(math.log2(min(feat_size) / min(window_size)))
|
||||||
self.global_block = FeatureBlock(dim, feat_levels)
|
self.global_block = FeatureBlock(dim, feat_levels)
|
||||||
self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
|
self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
|
||||||
|
|
||||||
@ -400,7 +401,8 @@ class GlobalContextVit(nn.Module):
|
|||||||
num_classes: int = 1000,
|
num_classes: int = 1000,
|
||||||
global_pool: str = 'avg',
|
global_pool: str = 'avg',
|
||||||
img_size: Tuple[int, int] = 224,
|
img_size: Tuple[int, int] = 224,
|
||||||
window_size: Tuple[int, ...] = (7, 7, 14, 7),
|
window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
|
||||||
|
window_size: Tuple[int, ...] = None,
|
||||||
embed_dim: int = 64,
|
embed_dim: int = 64,
|
||||||
depths: Tuple[int, ...] = (3, 4, 19, 5),
|
depths: Tuple[int, ...] = (3, 4, 19, 5),
|
||||||
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
|
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
|
||||||
@ -411,7 +413,7 @@ class GlobalContextVit(nn.Module):
|
|||||||
proj_drop_rate: float = 0.,
|
proj_drop_rate: float = 0.,
|
||||||
attn_drop_rate: float = 0.,
|
attn_drop_rate: float = 0.,
|
||||||
drop_path_rate: float = 0.,
|
drop_path_rate: float = 0.,
|
||||||
weight_init='vit',
|
weight_init='',
|
||||||
act_layer: str = 'gelu',
|
act_layer: str = 'gelu',
|
||||||
norm_layer: str = 'layernorm2d',
|
norm_layer: str = 'layernorm2d',
|
||||||
norm_layer_cl: str = 'layernorm',
|
norm_layer_cl: str = 'layernorm',
|
||||||
@ -429,6 +431,11 @@ class GlobalContextVit(nn.Module):
|
|||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
num_stages = len(depths)
|
num_stages = len(depths)
|
||||||
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
|
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
|
||||||
|
if window_size is not None:
|
||||||
|
window_size = to_ntuple(num_stages)(window_size)
|
||||||
|
else:
|
||||||
|
assert window_ratio is not None
|
||||||
|
window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
|
||||||
|
|
||||||
self.stem = Stem(
|
self.stem = Stem(
|
||||||
in_chs=in_chans,
|
in_chs=in_chans,
|
||||||
@ -480,7 +487,7 @@ class GlobalContextVit(nn.Module):
|
|||||||
nn.init.zeros_(module.bias)
|
nn.init.zeros_(module.bias)
|
||||||
else:
|
else:
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
trunc_normal_tf_(module.weight, std=.02)
|
nn.init.normal_(module.weight, std=.02)
|
||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
nn.init.zeros_(module.bias)
|
nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
@ -490,7 +497,6 @@ class GlobalContextVit(nn.Module):
|
|||||||
k for k, _ in self.named_parameters()
|
k for k, _ in self.named_parameters()
|
||||||
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
|
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def group_matcher(self, coarse=False):
|
def group_matcher(self, coarse=False):
|
||||||
matcher = dict(
|
matcher = dict(
|
||||||
@ -567,7 +573,6 @@ def gcvit_small(pretrained=False, **kwargs):
|
|||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
depths=(3, 4, 19, 5),
|
depths=(3, 4, 19, 5),
|
||||||
num_heads=(3, 6, 12, 24),
|
num_heads=(3, 6, 12, 24),
|
||||||
window_size=(7, 7, 14, 7),
|
|
||||||
embed_dim=96,
|
embed_dim=96,
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
layer_scale=1e-5,
|
layer_scale=1e-5,
|
||||||
@ -580,7 +585,6 @@ def gcvit_base(pretrained=False, **kwargs):
|
|||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
depths=(3, 4, 19, 5),
|
depths=(3, 4, 19, 5),
|
||||||
num_heads=(4, 8, 16, 32),
|
num_heads=(4, 8, 16, 32),
|
||||||
window_size=(7, 7, 14, 7),
|
|
||||||
embed_dim=128,
|
embed_dim=128,
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
layer_scale=1e-5,
|
layer_scale=1e-5,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user