Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Enhancement] Support hrnet frozen stage (#743)
* support hrnet frozen stage
This commit is contained in:
parent 2f3f027c3d
commit 778961dd2e
mmseg/models/backbones/hrnet.py

@@ -230,6 +230,8 @@ class HRNet(BaseModule):
             and its variants only.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Default: -1.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
         pretrained (str, optional): model pretrained path. Default: None
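Not part of the diff: a minimal usage sketch of the new argument. The stage settings below simply mirror the test added later in this commit, and frozen_stages=1 is an arbitrary example value.

from mmseg.models.backbones import HRNet

# Stage configuration mirrored from the new test in this commit.
extra = dict(
    stage1=dict(num_modules=1, num_branches=1, block='BOTTLENECK',
                num_blocks=(4, ), num_channels=(64, )),
    stage2=dict(num_modules=1, num_branches=2, block='BASIC',
                num_blocks=(4, 4), num_channels=(32, 64)),
    stage3=dict(num_modules=4, num_branches=3, block='BASIC',
                num_blocks=(4, 4, 4), num_channels=(32, 64, 128)),
    stage4=dict(num_modules=3, num_branches=4, block='BASIC',
                num_blocks=(4, 4, 4, 4), num_channels=(32, 64, 128, 256)))

# frozen_stages=-1 (default) freezes nothing; 0 freezes only the stem convs/norms;
# 1 additionally freezes layer1 and transition1, and so on up to stage4.
backbone = HRNet(extra, frozen_stages=1)
backbone.train()  # frozen modules are put back into eval mode by _freeze_stages()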
@@ -285,6 +287,7 @@ class HRNet(BaseModule):
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
+                 frozen_stages=-1,
                  zero_init_residual=False,
                  pretrained=None,
                  init_cfg=None):
@@ -315,6 +318,7 @@ class HRNet(BaseModule):
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
+        self.frozen_stages = frozen_stages

         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ class HRNet(BaseModule):
         self.stage4, pre_stage_channels = self._make_stage(
             self.stage4_cfg, num_channels)

+        self._freeze_stages()
+
     @property
     def norm1(self):
         """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ class HRNet(BaseModule):

         return Sequential(*hr_modules), in_channels

+    def _freeze_stages(self):
+        """Freeze stages param and norm stats."""
+        if self.frozen_stages >= 0:
+
+            self.norm1.eval()
+            self.norm2.eval()
+            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
+                for param in m.parameters():
+                    param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            if i == 1:
+                m = getattr(self, f'layer{i}')
+                t = getattr(self, f'transition{i}')
+            elif i == 4:
+                m = getattr(self, f'stage{i}')
+            else:
+                m = getattr(self, f'stage{i}')
+                t = getattr(self, f'transition{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+            t.eval()
+            for param in t.parameters():
+                param.requires_grad = False
+
     def forward(self, x):
         """Forward function."""

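Not part of the diff: a small summary of which modules each frozen_stages value covers, as implemented in _freeze_stages() above. Freezing is cumulative (index i freezes everything listed for indices 0..i), and HRNet has no transition4, which is why the i == 4 branch only looks up stage4.

# Summary only, not committed code: modules frozen at each stage index.
FROZEN_MODULES_BY_STAGE = {
    0: ['conv1', 'norm1', 'conv2', 'norm2'],  # the stem, frozen whenever frozen_stages >= 0
    1: ['layer1', 'transition1'],
    2: ['stage2', 'transition2'],
    3: ['stage3', 'transition3'],
    4: ['stage4'],  # no transition4 exists in HRNet
}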
@@ -575,6 +607,7 @@ class HRNet(BaseModule):
         """Convert the model into training mode while keeping the normalization
         layer frozen."""
         super(HRNet, self).train(mode)
+        self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():
                 # trick: eval has effect on BatchNorm only
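Because train() now re-applies the freezing, one quick sanity check (a sketch, not part of this commit) is to count trainable versus total parameters after switching a backbone built with frozen_stages >= 0 back to training mode. The helper below works on any nn.Module.

import torch.nn as nn


def count_parameters(module: nn.Module):
    """Return (trainable, total) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total

# Example, assuming `backbone` was built as in the sketch near the top of this page:
#   backbone.train()
#   trainable, total = count_parameters(backbone)
#   print(f'{trainable}/{total} parameters still receive gradients')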
tests/test_models/test_backbones/test_hrnet.py (new file, 63 lines)

@@ -0,0 +1,63 @@
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.models.backbones import HRNet


def test_hrnet_backbone():
    # Test HRNet with two stages frozen

    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)),
        stage4=dict(
            num_modules=3,
            num_branches=4,
            block='BASIC',
            num_blocks=(4, 4, 4, 4),
            num_channels=(32, 64, 128, 256)))
    frozen_stages = 2
    model = HRNet(extra, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    assert model.norm1.training is False

    for layer in [model.conv1, model.norm1]:
        for param in layer.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        if i == 1:
            layer = getattr(model, f'layer{i}')
            transition = getattr(model, f'transition{i}')
        elif i == 4:
            layer = getattr(model, f'stage{i}')
        else:
            layer = getattr(model, f'stage{i}')
            transition = getattr(model, f'transition{i}')

        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

        for mod in transition.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in transition.parameters():
            assert param.requires_grad is False