[Feature] Add a freezing function for the Swin Transformer model. (#574)

* Add a freezing function for the Swin Transformer model
* Add the frozen_stages parameter to the Swin Transformer model
* Add the norm_eval parameter
* Delete =11.1
* Delete =418,driver
* Delete _BatchNorm
* Remove LayerNorm, add _BatchNorm
* Unify the style of the freezing function with ResNet's
* Improve docs and add unit tests

Co-authored-by: cxiang26 <cq.xiang@foxmail.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
parent ca0cf41df9 · commit 0aa789f3c3
mmcls/models/backbones/swin_transformer.py

@@ -9,6 +9,7 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN
 from mmcv.cnn.utils.weight_init import trunc_normal_
 from mmcv.runner.base_module import BaseModule, ModuleList
+from mmcv.utils.parrots_wrapper import _BatchNorm

 from ..builder import BACKBONES
 from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
@@ -235,6 +236,11 @@ class SwinTransformer(BaseBackbone):
         with_cp (bool, optional): Use checkpoint or not. Using checkpoint
             will save some memory while slowing down the training speed.
             Defaults to False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
         auto_pad (bool): If True, auto pad feature map to fit window_size.
             Defaults to False.
         norm_cfg (dict, optional): Config dict for normalization layer at end
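Taken together, the two new options give Swin the same freezing interface as the ResNet backbone. A minimal usage sketch (the arch value here is just an example):

    from mmcls.models.backbones import SwinTransformer

    # Freeze the patch embedding and stage 0; keep any BatchNorm layers
    # in eval mode (running stats frozen) even while training.
    model = SwinTransformer(arch='tiny', frozen_stages=0, norm_eval=True)
    model.train()  # frozen parts stay in eval mode with requires_grad=False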
@@ -291,6 +297,8 @@ class SwinTransformer(BaseBackbone):
                  use_abs_pos_embed=False,
                  auto_pad=False,
                  with_cp=False,
+                 frozen_stages=-1,
+                 norm_eval=False,
                  norm_cfg=dict(type='LN'),
                  stage_cfgs=dict(),
                  patch_cfg=dict(),
@@ -315,6 +323,7 @@ class SwinTransformer(BaseBackbone):
         self.out_indices = out_indices
         self.use_abs_pos_embed = use_abs_pos_embed
         self.auto_pad = auto_pad
+        self.frozen_stages = frozen_stages

         _patch_cfg = {
             'img_size': img_size,
@@ -334,6 +343,7 @@ class SwinTransformer(BaseBackbone):
                 torch.zeros(1, num_patches, self.embed_dims))

         self.drop_after_pos = nn.Dropout(p=drop_rate)
+        self.norm_eval = norm_eval

         # stochastic depth
         total_depth = sum(self.depths)
@@ -425,3 +435,28 @@ class SwinTransformer(BaseBackbone):

         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       *args, **kwargs)
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+            for i in range(0, self.frozen_stages + 1):
+                m = self.stages[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+            for i in self.out_indices:
+                if i <= self.frozen_stages:
+                    for param in getattr(self, f'norm{i}').parameters():
+                        param.requires_grad = False
+
+    def train(self, mode=True):
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
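Overriding train() means freezing is re-applied every time the module returns to training mode, so a later model.train() call cannot silently unfreeze the stages; this mirrors the ResNet backbone, as the commit message notes. A practical consequence for fine-tuning (a sketch, assuming a standard PyTorch optimizer setup) is that only the still-trainable parameters need to reach the optimizer:

    import torch

    # Frozen parameters already have requires_grad=False, so simply
    # filter them out when building the optimizer.
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)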
tests/…/test_swin_transformer.py
@@ -7,11 +7,21 @@ import numpy as np
 import pytest
 import torch
 from mmcv.runner import load_checkpoint, save_checkpoint
+from mmcv.utils.parrots_wrapper import _BatchNorm

 from mmcls.models.backbones import SwinTransformer
 from mmcls.models.backbones.swin_transformer import SwinBlock


+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
 def test_assertion():
     """Test Swin Transformer backbone."""
     with pytest.raises(AssertionError):
@@ -171,6 +181,41 @@ def test_structure():
             assert np.isclose(block.attn.drop.drop_prob, expect_prob)
             pos += 1

+    # Test Swin-Transformer with norm_eval=True
+    model = SwinTransformer(
+        arch='small',
+        norm_eval=True,
+        norm_cfg=dict(type='BN'),
+        stage_cfgs=dict(block_cfgs=dict(norm_cfg=dict(type='BN'))),
+    )
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    # Test Swin-Transformer with first stage frozen.
+    frozen_stages = 0
+    model = SwinTransformer(
+        arch='small', frozen_stages=frozen_stages, out_indices=(0, 1, 2, 3))
+    model.init_weights()
+    model.train()
+
+    assert model.patch_embed.training is False
+    for param in model.patch_embed.parameters():
+        assert param.requires_grad is False
+    for i in range(frozen_stages + 1):
+        stage = model.stages[i]
+        for param in stage.parameters():
+            assert param.requires_grad is False
+    for param in model.norm0.parameters():
+        assert param.requires_grad is False
+
+    # The second stage should still require grad.
+    stage = model.stages[1]
+    for param in stage.parameters():
+        assert param.requires_grad is True
+    for param in model.norm1.parameters():
+        assert param.requires_grad is True
+

 def test_load_checkpoint():
     model = SwinTransformer(arch='tiny')
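The new assertions extend the existing test_structure case, so they can be exercised on their own with pytest's keyword filter (an example invocation; adjust the path to wherever this test file lives in the checkout):

    pytest tests/ -k test_structure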