diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index b2915bf8..85cbddba 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -12,6 +12,7 @@ This implementation is experimental and subject to change in manners that will b
   GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
 * num_heads per stage is not detailed for Huge and Giant model variants
 * 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
+* experiments are ongoing wrt to 'main branch' norm layer use and weight init scheme
 
 Noteworthy additions over official Swin v1:
 * MLP relative position embedding is looking promising and adapts to different image/window sizes
@@ -37,7 +38,7 @@ import torch.utils.checkpoint as checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
-from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
+from .helpers import build_model_with_cfg, named_apply
 from .layers import DropPath, Mlp, to_2tuple, _assert
 from .registry import register_model
 from .vision_transformer import checkpoint_filter_fn
@@ -67,27 +68,29 @@ default_cfgs = {
     'swin_v2_cr_tiny_384': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0),
     'swin_v2_cr_tiny_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=1.0),
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
+    'swin_v2_cr_tiny_ns_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swin_v2_cr_small_384': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0),
     'swin_v2_cr_small_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=1.0),
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swin_v2_cr_base_384': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0),
     'swin_v2_cr_base_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=1.0),
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swin_v2_cr_large_384': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0),
     'swin_v2_cr_large_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=1.0),
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swin_v2_cr_huge_384': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0),
     'swin_v2_cr_huge_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=1.0),
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swin_v2_cr_giant_384': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0),
     'swin_v2_cr_giant_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=1.0),
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
 }
@@ -175,7 +178,7 @@ class WindowMultiHeadAttention(nn.Module):
             hidden_features=meta_hidden_dim,
             out_features=num_heads,
             act_layer=nn.ReLU,
-            drop=0.  # FIXME should we add stochasticity?
+            drop=0.1  # FIXME should there be stochasticity, appears to 'overfit' without?
         )
         self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
         self._make_pair_wise_relative_positions()
@@ -336,7 +339,8 @@ class SwinTransformerBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
 
-        # extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?)
+        # Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.
+        # Also being used as final network norm and optional stage ending norm while still in a C-last format.
         self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
 
         self._make_attention_mask()
@@ -392,13 +396,16 @@ class SwinTransformerBlock(nn.Module):
         x = x.view(B, H, W, C)
 
         # cyclic shift
-        if any(self.shift_size):
-            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
-        else:
-            shifted_x = x
+        sh, sw = self.shift_size
+        do_shift: bool = any(self.shift_size)
+        if do_shift:
+            # FIXME PyTorch XLA needs cat impl, roll not lowered
+            # x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
+            # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
+            x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))
 
         # partition windows
-        x_windows = window_partition(shifted_x, self.window_size)  # num_windows * B, window_size, window_size, C
+        x_windows = window_partition(x, self.window_size)  # num_windows * B, window_size, window_size, C
         x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
 
         # W-MSA/SW-MSA
@@ -406,13 +413,14 @@
 
         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
-        shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C
+        x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C
 
         # reverse cyclic shift
-        if any(self.shift_size):
-            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
-        else:
-            x = shifted_x
+        if do_shift:
+            # FIXME PyTorch XLA needs cat impl, roll not lowered
+            # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
+            # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
+            x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
 
         x = x.view(B, L, C)
         return x
@@ -429,7 +437,7 @@ class SwinTransformerBlock(nn.Module):
         # NOTE post-norm branches (op -> norm -> drop)
         x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
         x = x + self.drop_path2(self.norm2(self.mlp(x)))
-        x = self.norm3(x)  # main-branch norm enabled for some blocks (every 6 for Huge/Giant)
+        x = self.norm3(x)  # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
         return x
@@ -452,8 +460,10 @@ class PatchMerging(nn.Module):
         Returns:
             output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
         """
-        x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
-        x = x.permute(0, 1, 2, 5, 4, 3).flatten(3)  # permute maintains compat with ch order in official swin impl
+        B, C, H, W = x.shape
+        # unfold + BCHW -> BHWC together
+        # ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
+        x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
         x = self.norm(x)
         x = bhwc_to_bchw(self.reduction(x))
         return x
@@ -497,8 +507,8 @@ class SwinTransformerStage(nn.Module):
         drop_attn (float): Dropout rate of attention map
         drop_path (float): Dropout in main path
         norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
-        grad_checkpointing (bool): If true checkpointing is utilized
         extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
+        extra_norm_stage (bool): End each stage with an extra norm layer in main branch
         sequential_attn (bool): If true sequential self-attention is performed
     """
@@ -515,17 +525,23 @@ class SwinTransformerStage(nn.Module):
         drop_attn: float = 0.0,
         drop_path: Union[List[float], float] = 0.0,
         norm_layer: Type[nn.Module] = nn.LayerNorm,
-        grad_checkpointing: bool = False,
         extra_norm_period: int = 0,
+        extra_norm_stage: bool = False,
         sequential_attn: bool = False,
     ) -> None:
         super(SwinTransformerStage, self).__init__()
         self.downscale: bool = downscale
-        self.grad_checkpointing: bool = grad_checkpointing
+        self.grad_checkpointing: bool = False
         self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
         self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
 
+        def _extra_norm(index):
+            i = index + 1
+            if extra_norm_period and i % extra_norm_period == 0:
+                return True
+            return i == depth if extra_norm_stage else False
+
         embed_dim = embed_dim * 2 if downscale else embed_dim
         self.blocks = nn.Sequential(*[
             SwinTransformerBlock(
@@ -538,7 +554,7 @@ class SwinTransformerStage(nn.Module):
                 drop=drop,
                 drop_attn=drop_attn,
                 drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
-                extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False,
+                extra_norm=_extra_norm(index),
                 sequential_attn=sequential_attn,
                 norm_layer=norm_layer,
             )
@@ -600,9 +616,9 @@ class SwinTransformerV2Cr(nn.Module):
         attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
         drop_path_rate (float): Stochastic depth rate. Default: 0.0
         norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
-        grad_checkpointing (bool): If true checkpointing is utilized. Default: False
+        extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
+        extra_norm_stage (bool): End each stage with an extra norm layer in main branch
         sequential_attn (bool): If true sequential self-attention is performed. Default: False
-        use_deformable (bool): If true deformable block is used. Default: False
     """
 
     def __init__(
@@ -621,10 +637,11 @@ class SwinTransformerV2Cr(nn.Module):
         attn_drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
         norm_layer: Type[nn.Module] = nn.LayerNorm,
-        grad_checkpointing: bool = False,
         extra_norm_period: int = 0,
+        extra_norm_stage: bool = False,
         sequential_attn: bool = False,
         global_pool: str = 'avg',
+        weight_init='skip',
         **kwargs: Any
     ) -> None:
         super(SwinTransformerV2Cr, self).__init__()
@@ -638,7 +655,7 @@ class SwinTransformerV2Cr(nn.Module):
         self.window_size: int = window_size
         self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
 
-        self.patch_embed: nn.Module = PatchEmbed(
+        self.patch_embed = PatchEmbed(
             img_size=img_size, patch_size=patch_size, in_chans=in_chans,
             embed_dim=embed_dim, norm_layer=norm_layer)
         patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
@@ -659,8 +676,8 @@ class SwinTransformerV2Cr(nn.Module):
                 drop=drop_rate,
                 drop_attn=attn_drop_rate,
                 drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
-                grad_checkpointing=grad_checkpointing,
                 extra_norm_period=extra_norm_period,
+                extra_norm_stage=extra_norm_stage or (index + 1) == len(depths),  # last stage ends w/ norm
                 sequential_attn=sequential_attn,
                 norm_layer=norm_layer,
             )
@@ -668,12 +685,12 @@ class SwinTransformerV2Cr(nn.Module):
         self.stages = nn.Sequential(*stages)
 
         self.global_pool: str = global_pool
-        self.head: nn.Module = nn.Linear(
-            in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity()
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
 
-        # FIXME weight init TBD, PyTorch default init appears to be working well,
-        # but differs from usual ViT or Swin init.
-        # named_apply(init_weights, self)
+        # current weight init skips custom init and uses pytorch layer defaults, seems to work well
+        # FIXME more experiments needed
+        if weight_init != 'skip':
+            named_apply(init_weights, self)
 
     def update_input_size(
         self,
@@ -704,13 +721,28 @@ class SwinTransformerV2Cr(nn.Module):
                 new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
             )
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore()
     def get_classifier(self) -> nn.Module:
         """Method returns the classification head of the model.
         Returns:
             head (nn.Module): Current classification head
         """
-        head: nn.Module = self.head
-        return head
+        return self.head
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
         """Method results the classification head
@@ -722,8 +754,7 @@ class SwinTransformerV2Cr(nn.Module):
         self.num_classes: int = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
-        self.head: nn.Module = nn.Linear(
-            in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
@@ -742,41 +773,28 @@ class SwinTransformerV2Cr(nn.Module):
 def init_weights(module: nn.Module, name: str = ''):
-    # FIXME WIP
+    # FIXME WIP determining if there's a better weight init
     if isinstance(module, nn.Linear):
         if 'qkv' in name:
             # treat the weights of Q, K, V separately
             val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
             nn.init.uniform_(module.weight, -val, val)
+        elif 'head' in name:
+            nn.init.zeros_(module.weight)
         else:
             nn.init.xavier_uniform_(module.weight)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
 
 
-def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
-    if default_cfg is None:
-        default_cfg = deepcopy(default_cfgs[variant])
-    overlay_external_default_cfg(default_cfg, kwargs)
-    default_num_classes = default_cfg['num_classes']
-    default_img_size = default_cfg['input_size'][-2:]
-
-    num_classes = kwargs.pop('num_classes', default_num_classes)
-    img_size = kwargs.pop('img_size', default_img_size)
+def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
     model = build_model_with_cfg(
-        SwinTransformerV2Cr,
-        variant,
-        pretrained,
-        default_cfg=default_cfg,
-        img_size=img_size,
-        num_classes=num_classes,
+        SwinTransformerV2Cr, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs
     )
-
     return model
@@ -804,6 +822,21 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
     return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
 
 
+@register_model
+def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs):
+    """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.
+    ** Experimental, may make default if results are improved. **
+    """
+    model_kwargs = dict(
+        embed_dim=96,
+        depths=(2, 2, 6, 2),
+        num_heads=(3, 6, 12, 24),
+        extra_norm_stage=True,
+        **kwargs
+    )
+    return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs)
+
+
 @register_model
 def swin_v2_cr_small_384(pretrained=False, **kwargs):
     """Swin-S V2 CR @ 384x384, trained ImageNet-1k"""
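Outside the patch itself, the newly registered `swin_v2_cr_tiny_ns_224` variant and the stage-level checkpointing toggle can be exercised through the standard timm factory. A minimal sketch, assuming this patch is applied to a local timm checkout and no pretrained weights have been published yet (so `pretrained=False`):

```python
import torch
import timm

# Build the new 'ns' (extra stage-norm) tiny variant added by this patch.
model = timm.create_model('swin_v2_cr_tiny_ns_224', pretrained=False)

# Gradient checkpointing is now toggled per stage via set_grad_checkpointing(),
# replacing the removed grad_checkpointing constructor argument.
model.set_grad_checkpointing(True)

x = torch.randn(2, 3, 224, 224)
logits = model(x)
print(logits.shape)  # torch.Size([2, 1000])
```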