Update SwinTransformerV2Cr post-merge, update with grad checkpointing / grad matcher
* weight compat break, activate norm3 for final block of final stage (equivalent to pre-head norm, but while still in BLC shape) * remove fold/unfold for TPU compat, add commented out roll code for TPU * add option for end of stage norm in all stages * allow weight_init to be selected between pytorch default inits and xavier / moco style vit variantpull/1014/head
parent
b049a5c5c6
commit
fe457c1996
|
@ -12,6 +12,7 @@ This implementation is experimental and subject to change in manners that will b
|
|||
GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
|
||||
* num_heads per stage is not detailed for Huge and Giant model variants
|
||||
* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
|
||||
* experiments are ongoing wrt to 'main branch' norm layer use and weight init scheme
|
||||
|
||||
Noteworthy additions over official Swin v1:
|
||||
* MLP relative position embedding is looking promising and adapts to different image/window sizes
|
||||
|
@ -37,7 +38,7 @@ import torch.utils.checkpoint as checkpoint
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_notrace_function
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
|
||||
from .helpers import build_model_with_cfg, named_apply
|
||||
from .layers import DropPath, Mlp, to_2tuple, _assert
|
||||
from .registry import register_model
|
||||
from .vision_transformer import checkpoint_filter_fn
|
||||
|
@ -67,27 +68,29 @@ default_cfgs = {
|
|||
'swin_v2_cr_tiny_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'swin_v2_cr_tiny_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=1.0),
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swin_v2_cr_tiny_ns_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swin_v2_cr_small_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'swin_v2_cr_small_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=1.0),
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swin_v2_cr_base_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'swin_v2_cr_base_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=1.0),
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swin_v2_cr_large_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'swin_v2_cr_large_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=1.0),
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swin_v2_cr_huge_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'swin_v2_cr_huge_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=1.0),
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swin_v2_cr_giant_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'swin_v2_cr_giant_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=1.0),
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
}
|
||||
|
||||
|
||||
|
@ -175,7 +178,7 @@ class WindowMultiHeadAttention(nn.Module):
|
|||
hidden_features=meta_hidden_dim,
|
||||
out_features=num_heads,
|
||||
act_layer=nn.ReLU,
|
||||
drop=0. # FIXME should we add stochasticity?
|
||||
drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without?
|
||||
)
|
||||
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
|
||||
self._make_pair_wise_relative_positions()
|
||||
|
@ -336,7 +339,8 @@ class SwinTransformerBlock(nn.Module):
|
|||
self.norm2 = norm_layer(dim)
|
||||
self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?)
|
||||
# Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.
|
||||
# Also being used as final network norm and optional stage ending norm while still in a C-last format.
|
||||
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
|
||||
|
||||
self._make_attention_mask()
|
||||
|
@ -392,13 +396,16 @@ class SwinTransformerBlock(nn.Module):
|
|||
x = x.view(B, H, W, C)
|
||||
|
||||
# cyclic shift
|
||||
if any(self.shift_size):
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
|
||||
else:
|
||||
shifted_x = x
|
||||
sh, sw = self.shift_size
|
||||
do_shift: bool = any(self.shift_size)
|
||||
if do_shift:
|
||||
# FIXME PyTorch XLA needs cat impl, roll not lowered
|
||||
# x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
|
||||
# x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
|
||||
x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(shifted_x, self.window_size) # num_windows * B, window_size, window_size, C
|
||||
x_windows = window_partition(x, self.window_size) # num_windows * B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
|
@ -406,13 +413,14 @@ class SwinTransformerBlock(nn.Module):
|
|||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
|
||||
x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if any(self.shift_size):
|
||||
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
if do_shift:
|
||||
# FIXME PyTorch XLA needs cat impl, roll not lowered
|
||||
# x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
|
||||
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
|
||||
x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
|
||||
|
||||
x = x.view(B, L, C)
|
||||
return x
|
||||
|
@ -429,7 +437,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
# NOTE post-norm branches (op -> norm -> drop)
|
||||
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
|
||||
x = x + self.drop_path2(self.norm2(self.mlp(x)))
|
||||
x = self.norm3(x) # main-branch norm enabled for some blocks (every 6 for Huge/Giant)
|
||||
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -452,8 +460,10 @@ class PatchMerging(nn.Module):
|
|||
Returns:
|
||||
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
|
||||
"""
|
||||
x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
|
||||
x = x.permute(0, 1, 2, 5, 4, 3).flatten(3) # permute maintains compat with ch order in official swin impl
|
||||
B, C, H, W = x.shape
|
||||
# unfold + BCHW -> BHWC together
|
||||
# ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
|
||||
x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
|
||||
x = self.norm(x)
|
||||
x = bhwc_to_bchw(self.reduction(x))
|
||||
return x
|
||||
|
@ -497,8 +507,8 @@ class SwinTransformerStage(nn.Module):
|
|||
drop_attn (float): Dropout rate of attention map
|
||||
drop_path (float): Dropout in main path
|
||||
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
|
||||
grad_checkpointing (bool): If true checkpointing is utilized
|
||||
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
|
||||
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
|
||||
sequential_attn (bool): If true sequential self-attention is performed
|
||||
"""
|
||||
|
||||
|
@ -515,17 +525,23 @@ class SwinTransformerStage(nn.Module):
|
|||
drop_attn: float = 0.0,
|
||||
drop_path: Union[List[float], float] = 0.0,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
grad_checkpointing: bool = False,
|
||||
extra_norm_period: int = 0,
|
||||
extra_norm_stage: bool = False,
|
||||
sequential_attn: bool = False,
|
||||
) -> None:
|
||||
super(SwinTransformerStage, self).__init__()
|
||||
self.downscale: bool = downscale
|
||||
self.grad_checkpointing: bool = grad_checkpointing
|
||||
self.grad_checkpointing: bool = False
|
||||
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
|
||||
|
||||
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
|
||||
|
||||
def _extra_norm(index):
|
||||
i = index + 1
|
||||
if extra_norm_period and i % extra_norm_period == 0:
|
||||
return True
|
||||
return i == depth if extra_norm_stage else False
|
||||
|
||||
embed_dim = embed_dim * 2 if downscale else embed_dim
|
||||
self.blocks = nn.Sequential(*[
|
||||
SwinTransformerBlock(
|
||||
|
@ -538,7 +554,7 @@ class SwinTransformerStage(nn.Module):
|
|||
drop=drop,
|
||||
drop_attn=drop_attn,
|
||||
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
|
||||
extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False,
|
||||
extra_norm=_extra_norm(index),
|
||||
sequential_attn=sequential_attn,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
@ -600,9 +616,9 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
|
||||
grad_checkpointing (bool): If true checkpointing is utilized. Default: False
|
||||
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
|
||||
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
|
||||
sequential_attn (bool): If true sequential self-attention is performed. Default: False
|
||||
use_deformable (bool): If true deformable block is used. Default: False
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -621,10 +637,11 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
attn_drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
grad_checkpointing: bool = False,
|
||||
extra_norm_period: int = 0,
|
||||
extra_norm_stage: bool = False,
|
||||
sequential_attn: bool = False,
|
||||
global_pool: str = 'avg',
|
||||
weight_init='skip',
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super(SwinTransformerV2Cr, self).__init__()
|
||||
|
@ -638,7 +655,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
self.window_size: int = window_size
|
||||
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
|
||||
|
||||
self.patch_embed: nn.Module = PatchEmbed(
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=norm_layer)
|
||||
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
|
||||
|
@ -659,8 +676,8 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
drop=drop_rate,
|
||||
drop_attn=attn_drop_rate,
|
||||
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
|
||||
grad_checkpointing=grad_checkpointing,
|
||||
extra_norm_period=extra_norm_period,
|
||||
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
|
||||
sequential_attn=sequential_attn,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
@ -668,12 +685,12 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
self.global_pool: str = global_pool
|
||||
self.head: nn.Module = nn.Linear(
|
||||
in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
|
||||
|
||||
# FIXME weight init TBD, PyTorch default init appears to be working well,
|
||||
# but differs from usual ViT or Swin init.
|
||||
# named_apply(init_weights, self)
|
||||
# current weight init skips custom init and uses pytorch layer defaults, seems to work well
|
||||
# FIXME more experiments needed
|
||||
if weight_init != 'skip':
|
||||
named_apply(init_weights, self)
|
||||
|
||||
def update_input_size(
|
||||
self,
|
||||
|
@ -704,13 +721,28 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
return dict(
|
||||
stem=r'^patch_embed', # stem and embed
|
||||
blocks=r'^stages\.(\d+)' if coarse else [
|
||||
(r'^stages\.(\d+).downsample', (0,)),
|
||||
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
|
||||
]
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
for s in self.stages:
|
||||
s.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore()
|
||||
def get_classifier(self) -> nn.Module:
|
||||
"""Method returns the classification head of the model.
|
||||
Returns:
|
||||
head (nn.Module): Current classification head
|
||||
"""
|
||||
head: nn.Module = self.head
|
||||
return head
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
|
||||
"""Method results the classification head
|
||||
|
@ -722,8 +754,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
self.num_classes: int = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
self.head: nn.Module = nn.Linear(
|
||||
in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
|
@ -742,41 +773,28 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
|
||||
|
||||
def init_weights(module: nn.Module, name: str = ''):
|
||||
# FIXME WIP
|
||||
# FIXME WIP determining if there's a better weight init
|
||||
if isinstance(module, nn.Linear):
|
||||
if 'qkv' in name:
|
||||
# treat the weights of Q, K, V separately
|
||||
val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
|
||||
nn.init.uniform_(module.weight, -val, val)
|
||||
elif 'head' in name:
|
||||
nn.init.zeros_(module.weight)
|
||||
else:
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
|
||||
if default_cfg is None:
|
||||
default_cfg = deepcopy(default_cfgs[variant])
|
||||
overlay_external_default_cfg(default_cfg, kwargs)
|
||||
default_num_classes = default_cfg['num_classes']
|
||||
default_img_size = default_cfg['input_size'][-2:]
|
||||
|
||||
num_classes = kwargs.pop('num_classes', default_num_classes)
|
||||
img_size = kwargs.pop('img_size', default_img_size)
|
||||
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformerV2Cr,
|
||||
variant,
|
||||
pretrained,
|
||||
default_cfg=default_cfg,
|
||||
img_size=img_size,
|
||||
num_classes=num_classes,
|
||||
SwinTransformerV2Cr, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
@ -804,6 +822,21 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
|
|||
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs):
|
||||
"""Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.
|
||||
** Experimental, may make default if results are improved. **
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
embed_dim=96,
|
||||
depths=(2, 2, 6, 2),
|
||||
num_heads=(3, 6, 12, 24),
|
||||
extra_norm_stage=True,
|
||||
**kwargs
|
||||
)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_small_384(pretrained=False, **kwargs):
|
||||
"""Swin-S V2 CR @ 384x384, trained ImageNet-1k"""
|
||||
|
|
Loading…
Reference in New Issue