[Enhancement] Support hrnet frozen stage (#743)

* support hrnet frozen stage
Authored by sshuair on 2021-08-04 00:45:42 +08:00, committed by GitHub
parent 2f3f027c3d
commit 778961dd2e
2 changed files with 96 additions and 0 deletions
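
A minimal usage sketch for reviewers (not part of the commit): the only new key is `frozen_stages`; the surrounding keys simply follow the existing HRNet backbone config layout in mmseg and are shown for illustration.

    # Illustrative config delta; only `frozen_stages` comes from this commit.
    # frozen_stages=-1 (default) freezes nothing, 0 freezes only the stem,
    # 1 additionally freezes stage 1 (layer1 + transition1), and so on.
    model = dict(
        backbone=dict(
            type='HRNet',
            frozen_stages=1))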


@@ -230,6 +230,8 @@ class HRNet(BaseModule):
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
        pretrained (str, optional): model pretrained path. Default: None
@@ -285,6 +287,7 @@ class HRNet(BaseModule):
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 frozen_stages=-1,
                 zero_init_residual=False,
                 pretrained=None,
                 init_cfg=None):
@@ -315,6 +318,7 @@ class HRNet(BaseModule):
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.frozen_stages = frozen_stages

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ class HRNet(BaseModule):
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels)

        self._freeze_stages()

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ class HRNet(BaseModule):
        return Sequential(*hr_modules), in_channels

    def _freeze_stages(self):
        """Freeze stage parameters and norm statistics."""
        if self.frozen_stages >= 0:
            # Freeze the stem: conv1/norm1/conv2/norm2.
            self.norm1.eval()
            self.norm2.eval()
            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            # Stage 1 is built as `layer1`; later stages are `stage{i}`.
            # Every stage except the last is followed by `transition{i}`.
            if i == 1:
                m = getattr(self, f'layer{i}')
                t = getattr(self, f'transition{i}')
            elif i == 4:
                m = getattr(self, f'stage{i}')
            else:
                m = getattr(self, f'stage{i}')
                t = getattr(self, f'transition{i}')

            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            t.eval()
            for param in t.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Forward function."""
@@ -575,6 +607,7 @@ class HRNet(BaseModule):
        """Convert the model into training mode while keeping the normalization
        layers frozen."""
        super(HRNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
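
A small behavioural sketch (not part of the diff): because `train()` re-applies `_freeze_stages()`, switching the model back to training mode never un-freezes the stem or the frozen stages. `extra` below stands for any valid stage configuration, for example the dict built in the new test file.

    # Sketch only; `extra` is assumed to be a valid HRNet stage config such as
    # the one constructed in the test below.
    model = HRNet(extra, frozen_stages=1)
    model.train()  # internally re-runs _freeze_stages()

    assert model.norm1.training is False               # stem BN stays in eval mode
    assert all(not p.requires_grad
               for p in model.layer1.parameters())     # stage 1 stays frozen
    assert all(p.requires_grad
               for p in model.stage2.parameters())     # later stages stay trainable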


@@ -0,0 +1,63 @@
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.models.backbones import HRNet


def test_hrnet_backbone():
    # Test HRNet with the stem and the first two stages frozen.
    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)),
        stage4=dict(
            num_modules=3,
            num_branches=4,
            block='BASIC',
            num_blocks=(4, 4, 4, 4),
            num_channels=(32, 64, 128, 256)))

    frozen_stages = 2
    model = HRNet(extra, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()

    # The stem must be frozen and its BN kept in eval mode.
    assert model.norm1.training is False
    for layer in [model.conv1, model.norm1]:
        for param in layer.parameters():
            assert param.requires_grad is False

    # Every frozen stage (and its transition) must be in eval mode with
    # gradients disabled.
    for i in range(1, frozen_stages + 1):
        if i == 1:
            layer = getattr(model, f'layer{i}')
            transition = getattr(model, f'transition{i}')
        elif i == 4:
            layer = getattr(model, f'stage{i}')
        else:
            layer = getattr(model, f'stage{i}')
            transition = getattr(model, f'transition{i}')

        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

        for mod in transition.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in transition.parameters():
            assert param.requires_grad is False
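
A complementary assertion block one could append to the test above (not part of the commit), checking that stages beyond `frozen_stages` are left trainable:

    # Hypothetical extra checks appended inside test_hrnet_backbone above:
    # stages past `frozen_stages` keep training-mode BatchNorm and gradients.
    for name in ['stage3', 'transition3', 'stage4']:
        module = getattr(model, name)
        for mod in module.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is True
        for param in module.parameters():
            assert param.requires_grad is True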