mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Delete frozen parameters when using paramwise_cfg
(#1441)
This commit is contained in:
parent
9ecced821b
commit
acbc5e46dc
@ -213,7 +213,10 @@ class DefaultOptimWrapperConstructor:
|
|||||||
level=logging.WARNING)
|
level=logging.WARNING)
|
||||||
continue
|
continue
|
||||||
if not param.requires_grad:
|
if not param.requires_grad:
|
||||||
params.append(param_group)
|
print_log((f'{prefix}.{name} is skipped since its '
|
||||||
|
f'requires_grad={param.requires_grad}'),
|
||||||
|
logger='current',
|
||||||
|
level=logging.WARNING)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# if the parameter match one of the custom keys, ignore other rules
|
# if the parameter match one of the custom keys, ignore other rules
|
||||||
|
@ -549,7 +549,8 @@ class TestBuilder(TestCase):
|
|||||||
weight_decay=self.base_wd,
|
weight_decay=self.base_wd,
|
||||||
momentum=self.momentum))
|
momentum=self.momentum))
|
||||||
paramwise_cfg = dict()
|
paramwise_cfg = dict()
|
||||||
optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
|
optim_constructor = DefaultOptimWrapperConstructor(
|
||||||
|
optim_wrapper_cfg, paramwise_cfg)
|
||||||
optim_wrapper = optim_constructor(model)
|
optim_wrapper = optim_constructor(model)
|
||||||
self._check_default_optimizer(optim_wrapper.optimizer, model)
|
self._check_default_optimizer(optim_wrapper.optimizer, model)
|
||||||
|
|
||||||
@ -591,23 +592,16 @@ class TestBuilder(TestCase):
|
|||||||
dwconv_decay_mult=0.1,
|
dwconv_decay_mult=0.1,
|
||||||
dcn_offset_lr_mult=0.1)
|
dcn_offset_lr_mult=0.1)
|
||||||
|
|
||||||
for param in self.model.parameters():
|
self.model.conv1.requires_grad_(False)
|
||||||
param.requires_grad = False
|
|
||||||
optim_constructor = DefaultOptimWrapperConstructor(
|
optim_constructor = DefaultOptimWrapperConstructor(
|
||||||
optim_wrapper_cfg, paramwise_cfg)
|
optim_wrapper_cfg, paramwise_cfg)
|
||||||
optim_wrapper = optim_constructor(self.model)
|
optim_wrapper = optim_constructor(self.model)
|
||||||
optimizer = optim_wrapper.optimizer
|
|
||||||
param_groups = optimizer.param_groups
|
all_params = []
|
||||||
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
for pg in optim_wrapper.param_groups:
|
||||||
assert optimizer.defaults['lr'] == self.base_lr
|
all_params.extend(map(id, pg['params']))
|
||||||
assert optimizer.defaults['momentum'] == self.momentum
|
self.assertNotIn(id(self.model.conv1.weight), all_params)
|
||||||
assert optimizer.defaults['weight_decay'] == self.base_wd
|
self.assertIn(id(self.model.conv2.weight), all_params)
|
||||||
for i, (name, param) in enumerate(self.model.named_parameters()):
|
|
||||||
param_group = param_groups[i]
|
|
||||||
assert torch.equal(param_group['params'][0], param)
|
|
||||||
assert param_group['momentum'] == self.momentum
|
|
||||||
assert param_group['lr'] == self.base_lr
|
|
||||||
assert param_group['weight_decay'] == self.base_wd
|
|
||||||
|
|
||||||
def test_default_optimizer_constructor_bypass_duplicate(self):
|
def test_default_optimizer_constructor_bypass_duplicate(self):
|
||||||
# paramwise_cfg with bypass_duplicate option
|
# paramwise_cfg with bypass_duplicate option
|
||||||
@ -663,10 +657,8 @@ class TestBuilder(TestCase):
|
|||||||
optim_wrapper = optim_constructor(model)
|
optim_wrapper = optim_constructor(model)
|
||||||
model_parameters = list(model.parameters())
|
model_parameters = list(model.parameters())
|
||||||
num_params = 14 if MMCV_FULL_AVAILABLE else 11
|
num_params = 14 if MMCV_FULL_AVAILABLE else 11
|
||||||
assert len(optim_wrapper.optimizer.param_groups) == len(
|
assert len(optim_wrapper.optimizer.param_groups
|
||||||
model_parameters) == num_params
|
) == len(model_parameters) - 1 == num_params - 1
|
||||||
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
|
|
||||||
**paramwise_cfg)
|
|
||||||
|
|
||||||
def test_default_optimizer_constructor_custom_key(self):
|
def test_default_optimizer_constructor_custom_key(self):
|
||||||
# test DefaultOptimWrapperConstructor with custom_keys and
|
# test DefaultOptimWrapperConstructor with custom_keys and
|
||||||
|
Loading…
x
Reference in New Issue
Block a user