Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Remove persistent buffers from Swin-V2. Change SwinV2Cr cos attn + tau/logit_scale to match official, add ckpt convert, init_values zeros resid LN weight by default

This commit is contained in:
  parent 27c42f0830
  commit d4c0588012
timm/models/swin_transformer_v2.py

@@ -25,7 +25,6 @@ from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit


 def _cfg(url='', **kwargs):
@@ -75,7 +74,7 @@ default_cfgs = {
     ),
     'swinv2_base_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
     'swinv2_large_window12_192_22k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
@@ -87,7 +86,7 @@ default_cfgs = {
     ),
     'swinv2_large_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
 }

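
The crop_pct=1.0 on the 384x384 fine-tuned configs only changes eval preprocessing: the image is resized straight to the target size and no border is cropped away, matching how the official SwinV2 384 weights were evaluated. A rough sketch of how the eval resize size follows from crop_pct (simplified from timm's eval transform logic; interpolation details omitted):

import math

def eval_sizes(img_size=384, crop_pct=1.0):
    # resize the shorter side to img_size / crop_pct, then center-crop img_size
    scale_size = int(math.floor(img_size / crop_pct))
    return scale_size, img_size

print(eval_sizes(384, 0.9))   # (426, 384): resize to 426, crop out a 384 center
print(eval_sizes(384, 1.0))   # (384, 384): the whole resized image is evaluated
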
@@ -174,7 +173,7 @@ class WindowAttention(nn.Module):
         relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
             torch.abs(relative_coords_table) + 1.0) / math.log2(8)

-        self.register_buffer("relative_coords_table", relative_coords_table)
+        self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
@@ -187,7 +186,7 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
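
A minimal standalone sketch (toy module, not timm code) of what persistent=False changes: the buffer still moves with the module across devices and dtypes, but it is excluded from state_dict(), so it is no longer written to checkpoints or expected when loading them.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # saved in state_dict, must match on load
        self.register_buffer("persistent_idx", torch.arange(4))
        # recomputed in __init__, excluded from state_dict
        self.register_buffer("cached_idx", torch.arange(4), persistent=False)

m = Toy()
print(list(m.state_dict().keys()))  # ['persistent_idx'] only
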
@@ -215,7 +214,7 @@ class WindowAttention(nn.Module):
             qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)

         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
@@ -559,9 +558,6 @@ class SwinTransformerV2(nn.Module):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)

     @torch.jit.ignore
     def no_weight_decay(self):
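
Dropping the nn.LayerNorm branch is not a behavior change: PyTorch already initializes LayerNorm with weight=1 and bias=0, as this quick check shows.

import torch
import torch.nn as nn

ln = nn.LayerNorm(8)
print(torch.all(ln.weight == 1), torch.all(ln.bias == 0))  # tensor(True) tensor(True)
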
@@ -621,6 +617,18 @@ class SwinTransformerV2(nn.Module):
         return x


+def checkpoint_filter_fn(state_dict, model):
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
+            continue  # skip buffers that should not be persistent
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         SwinTransformerV2, variant, pretrained,
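
Why the model now needs its own checkpoint_filter_fn: checkpoints saved before this change (and the official SwinV2 weights) still carry relative_position_index / relative_coords_table entries, which would show up as unexpected keys now that those buffers are non-persistent and rebuilt in __init__. A small illustration with made-up key names and placeholder shapes following that layout:

import torch

old_ckpt = {
    'layers.0.blocks.0.attn.relative_position_index': torch.zeros(49, 49, dtype=torch.long),
    'layers.0.blocks.0.attn.relative_coords_table': torch.zeros(1, 13, 13, 2),
    'layers.0.blocks.0.attn.qkv.weight': torch.zeros(3 * 96, 96),
}
filtered = {
    k: v for k, v in old_ckpt.items()
    if not any(n in k for n in ('relative_position_index', 'relative_coords_table'))
}
print(list(filtered.keys()))  # only the qkv weight remains; the buffers are recomputed by the model
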
timm/models/swin_transformer_v2_cr.py

@@ -34,6 +34,7 @@ from typing import Tuple, Optional, List, Union, Any, Type

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -41,7 +42,7 @@ from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import DropPath, Mlp, to_2tuple, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn
+

 _logger = logging.getLogger(__name__)

@@ -186,12 +187,13 @@ class WindowMultiHeadAttention(nn.Module):
             act_layer=nn.ReLU,
             drop=(0.125, 0.)  # FIXME should there be stochasticity, appears to 'overfit' without?
         )
-        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
+        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
         self._make_pair_wise_relative_positions()

     def _make_pair_wise_relative_positions(self) -> None:
         """Method initializes the pair-wise relative positions to compute the positional biases."""
-        device = self.tau.device
+        device = self.logit_scale.device
         coordinates = torch.stack(torch.meshgrid([
             torch.arange(self.window_size[0], device=device),
             torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
@@ -250,10 +252,11 @@ class WindowMultiHeadAttention(nn.Module):
         query, key, value = qkv.unbind(0)

         # compute attention map with scaled cosine attention
-        denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
-        attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
-        attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
+        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
+        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
+        attn = attn * logit_scale
         attn = attn + self._relative_positional_encodings()
+
         if mask is not None:
             # Apply mask if utilized
             num_win: int = mask.shape[0]
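
A standalone sketch of the scaled cosine attention the block now computes (tensor shapes and the per-head logit_scale init here are assumptions for illustration). The clamp at log(1/0.01) is the counterpart of the old tau.clamp(min=0.01): since the learned scale equals 1/tau, capping the scale at 100 enforces the same bound.

import math
import torch
import torch.nn.functional as F

B, H, N, d = 2, 4, 49, 32          # windows, heads, tokens, head dim (toy values)
q = torch.randn(B, H, N, d)
k = torch.randn(B, H, N, d)
logit_scale = torch.log(10 * torch.ones(H))  # learned per-head parameter

sim = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)  # cosine similarity in [-1, 1]
scale = torch.clamp(logit_scale.reshape(1, H, 1, 1), max=math.log(1. / 0.01)).exp()  # <= 100
attn = (sim * scale).softmax(dim=-1)

# The old formulation divided by a clamped tau; with scale == 1 / tau the logits are identical.
tau = 1.0 / scale
print(torch.allclose(sim * scale, sim / tau))  # True
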
@@ -309,7 +312,7 @@ class SwinTransformerBlock(nn.Module):
             window_size: Tuple[int, int],
             shift_size: Tuple[int, int] = (0, 0),
             mlp_ratio: float = 4.0,
-            init_values: float = 0,
+            init_values: Optional[float] = 0,
             drop: float = 0.0,
             drop_attn: float = 0.0,
             drop_path: float = 0.0,
@@ -323,7 +326,7 @@ class SwinTransformerBlock(nn.Module):
         self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
         self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.init_values: float = init_values
+        self.init_values: Optional[float] = init_values

         # attn branch
         self.attn = WindowMultiHeadAttention(
@@ -387,7 +390,7 @@ class SwinTransformerBlock(nn.Module):

     def init_weights(self):
         # extra, module specific weight init
-        if self.init_values:
+        if self.init_values is not None:
             nn.init.constant_(self.norm1.weight, self.init_values)
             nn.init.constant_(self.norm2.weight, self.init_values)

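
The switch to `is not None` matters because the default init_values is now 0, which is falsy: previously a zero value skipped this init entirely, whereas now norm1/norm2 weights are set to 0 so each residual branch contributes nothing at initialization (LayerScale-style zero init). That is what "init_values zeros resid LN weight by default" in the commit message refers to. A toy sketch of the effect on a single post-norm residual branch (module names here are illustrative, not the timm block):

import torch
import torch.nn as nn

class ToyResidualBranch(nn.Module):
    def __init__(self, dim, init_values=0.0):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)  # post-norm applied to the residual branch
        if init_values is not None:
            nn.init.constant_(self.norm.weight, init_values)

    def forward(self, x):
        return x + self.norm(self.fc(x))  # branch output is 0 at init when weight == 0

x = torch.randn(2, 8)
branch = ToyResidualBranch(8)
print(torch.allclose(branch(x), x))  # True: the block starts out as an identity mapping
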
@@ -536,7 +539,7 @@ class SwinTransformerStage(nn.Module):
             feat_size: Tuple[int, int],
             window_size: Tuple[int, int],
             mlp_ratio: float = 4.0,
-            init_values: float = 0.0,
+            init_values: Optional[float] = 0.0,
             drop: float = 0.0,
             drop_attn: float = 0.0,
             drop_path: Union[List[float], float] = 0.0,
@@ -650,7 +653,7 @@ class SwinTransformerV2Cr(nn.Module):
             depths: Tuple[int, ...] = (2, 2, 6, 2),
             num_heads: Tuple[int, ...] = (3, 6, 12, 24),
             mlp_ratio: float = 4.0,
-            init_values: float = 0.0,
+            init_values: Optional[float] = 0.,
             drop_rate: float = 0.0,
             attn_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
@@ -808,6 +811,21 @@ def init_weights(module: nn.Module, name: str = ''):
         module.init_weights()


+def checkpoint_filter_fn(state_dict, model):
+    """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if 'tau' in k:
+            # convert old tau based checkpoints -> logit_scale (inverse)
+            v = torch.log(1 / v)
+            k = k.replace('tau', 'logit_scale')
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
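
A quick check of the conversion performed above, using a made-up per-head tau tensor: mapping tau to log(1/tau) and then applying the new clamped-exp scaling reproduces the old division by clamped tau, as long as tau is above the 0.01 floor.

import math
import torch

tau = torch.tensor([0.05, 0.5, 1.0, 2.0])        # hypothetical values from an old checkpoint
logit_scale = torch.log(1 / tau)                 # what checkpoint_filter_fn stores instead

sim = torch.randn(4, 7, 7)                       # per-head cosine similarities (toy shape)
old = sim / tau.clamp(min=0.01).reshape(4, 1, 1)
new = sim * torch.clamp(logit_scale, max=math.log(1. / 0.01)).exp().reshape(4, 1, 1)
print(torch.allclose(old, new, atol=1e-6))       # True: both parameterizations match
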
@@ -890,7 +908,6 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
         embed_dim=96,
         depths=(2, 2, 18, 2),
         num_heads=(3, 6, 12, 24),
-        init_values=1e-5,
         extra_norm_stage=True,
         **kwargs
     )
@@ -928,7 +945,6 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs):
         embed_dim=128,
         depths=(2, 2, 18, 2),
         num_heads=(4, 8, 16, 32),
-        init_values=1e-6,
         extra_norm_stage=True,
         **kwargs
     )