[Feature] Add the frozen function for Swin Transformer model. (#574)

* Add the frozen function for Swin Transformer model

* add frozen parameter for swin transformer model

* add norm_eval parameter

* Delete =11.1

* Delete =418,driver

* delete _BatchNorm

* remove LayerNorm, add _BatchNorm

* unify the style of the frozen function, referring to ResNet

* Improve docs and add unit tests.

Co-authored-by: cxiang26 <cq.xiang@foxmail.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
fangxu 2021-12-07 11:58:14 +08:00 committed by GitHub
parent ca0cf41df9
commit 0aa789f3c3
2 changed files with 80 additions and 0 deletions


@@ -9,6 +9,7 @@ from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
@@ -235,6 +236,11 @@ class SwinTransformer(BaseBackbone):
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: This only has an effect
on BatchNorm and its variants. Defaults to False.
auto_pad (bool): If True, auto pad feature map to fit window_size.
Defaults to False.
norm_cfg (dict, optional): Config dict for normalization layer at end
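
As a usage sketch of the two new options (not part of this diff; the arch value, learning setup, and the BN-based norm config below are illustrative and mirror the unit test), a backbone with frozen stages and norm layers kept in eval mode could be built like this:

from mmcls.models import build_backbone

# Hypothetical config: freeze the patch embedding plus stages 0 and 1,
# and keep BatchNorm layers in eval mode during training. norm_eval only
# affects BatchNorm, hence the BN norm configs here.
cfg = dict(
    type='SwinTransformer',
    arch='small',
    frozen_stages=1,
    norm_eval=True,
    norm_cfg=dict(type='BN'),
    stage_cfgs=dict(block_cfgs=dict(norm_cfg=dict(type='BN'))))
backbone = build_backbone(cfg)
backbone.init_weights()
backbone.train()  # frozen stages and BN running stats stay in eval mode
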
@@ -291,6 +297,8 @@ class SwinTransformer(BaseBackbone):
use_abs_pos_embed=False,
auto_pad=False,
with_cp=False,
frozen_stages=-1,
norm_eval=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(),
patch_cfg=dict(),
@@ -315,6 +323,7 @@ class SwinTransformer(BaseBackbone):
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.auto_pad = auto_pad
self.frozen_stages = frozen_stages
_patch_cfg = {
'img_size': img_size,
@@ -334,6 +343,7 @@ class SwinTransformer(BaseBackbone):
torch.zeros(1, num_patches, self.embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
# stochastic depth
total_depth = sum(self.depths)
@@ -425,3 +435,28 @@ class SwinTransformer(BaseBackbone):
super()._load_from_state_dict(state_dict, prefix, local_metadata,
*args, **kwargs)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def train(self, mode=True):
super(SwinTransformer, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval() only has an effect on BatchNorm layers
if isinstance(m, _BatchNorm):
m.eval()
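
A minimal downstream sketch (the optimizer choice and learning rate are assumptions, not part of this change): because _freeze_stages() sets requires_grad = False on the frozen parameters, they can simply be filtered out when the optimizer is constructed:

import torch
from mmcls.models.backbones import SwinTransformer

model = SwinTransformer(arch='tiny', frozen_stages=0)
model.train()  # train() re-applies _freeze_stages(), so stage 0 stays frozen

# Only parameters that still require grad are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
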


@@ -7,11 +7,21 @@ import numpy as np
import pytest
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import SwinTransformer
from mmcls.models.backbones.swin_transformer import SwinBlock
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_assertion():
"""Test Swin Transformer backbone."""
with pytest.raises(AssertionError):
@@ -171,6 +181,41 @@ def test_structure():
assert np.isclose(block.attn.drop.drop_prob, expect_prob)
pos += 1
# Test Swin-Transformer with norm_eval=True
model = SwinTransformer(
arch='small',
norm_eval=True,
norm_cfg=dict(type='BN'),
stage_cfgs=dict(block_cfgs=dict(norm_cfg=dict(type='BN'))),
)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)
# Test Swin-Transformer with first stage frozen.
frozen_stages = 0
model = SwinTransformer(
arch='small', frozen_stages=frozen_stages, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
assert model.patch_embed.training is False
for param in model.patch_embed.parameters():
assert param.requires_grad is False
for i in range(frozen_stages + 1):
stage = model.stages[i]
for param in stage.parameters():
assert param.requires_grad is False
for param in model.norm0.parameters():
assert param.requires_grad is False
# the second stage should require grad.
stage = model.stages[1]
for param in stage.parameters():
assert param.requires_grad is True
for param in model.norm1.parameters():
assert param.requires_grad is True
def test_load_checkpoint():
model = SwinTransformer(arch='tiny')