[Enhance] Learning rate in log can show the base learning rate of optimizer (#1019)

pull/766/merge
Akide Liu 2023-06-08 21:21:15 +09:30 committed by GitHub
parent eb2dc671e9
commit 94e7a3bb57
6 changed files with 90 additions and 13 deletions

View File

@ -133,7 +133,7 @@ class AmpOptimWrapper(OptimWrapper):
:attr:`optimizer`.
"""
# save state_dict of loss_scaler
state_dict = self.optimizer.state_dict()
state_dict = super().state_dict()
state_dict['loss_scaler'] = self.loss_scaler.state_dict()
return state_dict
@ -151,6 +151,11 @@ class AmpOptimWrapper(OptimWrapper):
"""
if 'loss_scaler' in state_dict:
self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))
if 'base_param_settings' in state_dict:
self.base_param_settings = state_dict.pop('base_param_settings')
# load state_dict of optimizer
self.optimizer.load_state_dict(state_dict)
@contextmanager
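
Taken together with the `OptimWrapper` changes in the next file, these two hunks mean that the extra `base_param_settings` entry survives checkpointing alongside the AMP loss scaler. A minimal sketch of the round trip, assuming a CUDA-capable device (AmpOptimWrapper requires one) and a toy two-group SGD setup chosen purely for illustration:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(2, 2).cuda()
optimizer = SGD([{'params': model.weight, 'lr': 0.01},
                 {'params': model.bias}], lr=0.1)
wrapper = AmpOptimWrapper(optimizer=optimizer)

state = wrapper.state_dict()
# super().state_dict() contributes 'base_param_settings' (the optimizer
# has multiple param groups), and AmpOptimWrapper adds the scaler state.
assert 'base_param_settings' in state and 'loss_scaler' in state

# load_state_dict pops both extra keys before delegating the rest to
# the inner optimizer.
wrapper.load_state_dict(state)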

View File

@ -161,6 +161,20 @@ class OptimWrapper:
# the loss factor will always be the same as `_accumulative_counts`.
self._remainder_counts = -1
# The following code initializes `base_param_settings`.
# `base_param_settings` stores a dummy parameter that is not updated by
# the optimizer; it is only used to track the base learning rate of the
# optimizer. It is only created when the optimizer has multiple
# parameter groups, and it will not be scaled by the loss factor.
if len(optimizer.param_groups) > 1:
self.base_param_settings = {
'params': torch.tensor([0.0], dtype=torch.float)
}
self.base_param_settings.update(**self.optimizer.defaults)
else:
self.base_param_settings = None # type: ignore
def update_params(self,
loss: torch.Tensor,
step_kwargs: Optional[Dict] = None,
@ -251,7 +265,10 @@ class OptimWrapper:
Returns:
dict: The state dictionary of :attr:`optimizer`.
"""
return self.optimizer.state_dict()
state_dict = self.optimizer.state_dict()
if self.base_param_settings is not None:
state_dict['base_param_settings'] = self.base_param_settings
return state_dict
def load_state_dict(self, state_dict: dict) -> None:
"""A wrapper of ``Optimizer.load_state_dict``. load the state dict of
@ -265,6 +282,12 @@ class OptimWrapper:
Args:
state_dict (dict): The state dictionary of :attr:`optimizer`.
"""
base_param_settings = state_dict.pop('base_param_settings', None)
if base_param_settings is not None:
self.base_param_settings = base_param_settings
# load state_dict of optimizer
self.optimizer.load_state_dict(state_dict)
@property
@ -276,7 +299,10 @@ class OptimWrapper:
Returns:
dict: the ``param_groups`` of :attr:`optimizer`.
"""
return self.optimizer.param_groups
if self.base_param_settings is not None:
return self.optimizer.param_groups + [self.base_param_settings]
else:
return self.optimizer.param_groups
@property
def defaults(self) -> dict:
@ -295,10 +321,16 @@ class OptimWrapper:
Provide unified interface to get learning rate of optimizer.
Returns:
Dict[str, List[float]]: Learning rate of the optimizer.
Dict[str, List[float]]: Learning rate of the optimizer's
``param_groups``, plus the base learning rate under ``base_lr``
when ``base_param_settings`` is set.
"""
lr = [group['lr'] for group in self.param_groups]
return dict(lr=lr)
res = {}
if self.base_param_settings is not None:
res['base_lr'] = [self.base_param_settings['lr']]
res['lr'] = [group['lr'] for group in self.optimizer.param_groups]
return res
def get_momentum(self) -> Dict[str, List[float]]:
"""Get the momentum of the optimizer.
@ -309,7 +341,7 @@ class OptimWrapper:
Dict[str, List[float]]: Momentum of the optimizer.
"""
momentum = []
for group in self.param_groups:
for group in self.optimizer.param_groups:
# Get momentum of SGD.
if 'momentum' in group.keys():
momentum.append(group['momentum'])
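
To make the effect of the dummy parameter group concrete, here is a small sketch of `get_lr()` with and without multiple parameter groups (the `nn.Linear` model and the SGD settings are placeholders, not part of the change):

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(2, 2)

# Single param group: no dummy group is created, so only 'lr' is reported.
wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1))
print(wrapper.get_lr())  # {'lr': [0.1]}

# Multiple param groups: a dummy group built from the optimizer defaults
# tracks the base learning rate, reported separately as 'base_lr'.
optimizer = SGD([{'params': model.weight, 'lr': 0.01},
                 {'params': model.bias}], lr=0.1)
wrapper = OptimWrapper(optimizer)
print(wrapper.get_lr())  # {'base_lr': [0.1], 'lr': [0.01, 0.1]}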

View File

@ -115,7 +115,10 @@ class OptimWrapperDict(OptimWrapper):
"""
lr_dict = dict()
for name, optim_wrapper in self.optim_wrappers.items():
lr_dict[f'{name}.lr'] = optim_wrapper.get_lr()['lr']
inner_lr_dict = optim_wrapper.get_lr()
if 'base_lr' in inner_lr_dict:
lr_dict[f'{name}.base_lr'] = inner_lr_dict['base_lr']
lr_dict[f'{name}.lr'] = inner_lr_dict['lr']
return lr_dict
def get_momentum(self) -> Dict[str, List[float]]:
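
For `OptimWrapperDict`, the new key is simply namespaced per wrapper. A sketch with two wrapped optimizers (the names 'gen' and 'disc', the toy models, and the learning rates are made up for illustration):

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(2, 2), nn.Linear(2, 2)
gen_optim = SGD([{'params': gen.weight, 'lr': 0.01},
                 {'params': gen.bias}], lr=0.1)
wrappers = OptimWrapperDict(
    gen=OptimWrapper(gen_optim),
    disc=OptimWrapper(SGD(disc.parameters(), lr=0.2)))

# 'gen' has multiple param groups, so its base lr appears under the
# namespaced 'gen.base_lr' key; 'disc' only reports 'disc.lr'.
print(wrappers.get_lr())
# {'gen.base_lr': [0.1], 'gen.lr': [0.01, 0.1], 'disc.lr': [0.2]}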

View File

@ -1409,11 +1409,28 @@ class ReduceOnPlateauParamScheduler(_ParamScheduler):
raise ValueError('Factor should be < 1.0.')
self.factor = factor
# This code snippet handles compatibility with the optimizer wrapper.
# The optimizer wrapper may hold an additional dummy parameter group
# that records the base learning rate (lr), which is not affected by
# the paramwise_cfg. Retrieving this base lr gives the actual base lr
# that reflects the learning progress, so the checks below are done
# against the raw optimizer's param_groups.
if isinstance(optimizer, OptimWrapper):
raw_optimizer = optimizer.optimizer
else:
raw_optimizer = optimizer
if isinstance(min_value, (list, tuple)):
if len(min_value) != len(optimizer.param_groups):
if len(min_value) != len(raw_optimizer.param_groups):
raise ValueError('expected {} min_lrs, got {}'.format(
len(optimizer.param_groups), len(min_value)))
len(raw_optimizer.param_groups), len(min_value)))
self.min_values = list(min_value)
# Treat the `min_value` of the last param_group as the setting for
# the base group. This extra value is only appended when the
# optimizer is an OptimWrapper with `base_param_settings`.
if isinstance(optimizer, OptimWrapper) and \
optimizer.base_param_settings is not None:
self.min_values.append(self.min_values[-1])
else:
self.min_values = [min_value] * len( # type: ignore
optimizer.param_groups)
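
A sketch of how a list-valued `min_value` now interacts with the wrapper's dummy group (the toy model and the chosen values are illustrative only, and the remaining keyword defaults of `ReduceOnPlateauParamScheduler` are assumed to be acceptable):

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, ReduceOnPlateauParamScheduler

model = nn.Linear(2, 2)
optimizer = SGD([{'params': model.weight, 'lr': 0.01},
                 {'params': model.bias}], lr=0.1)
wrapper = OptimWrapper(optimizer)  # creates the extra base param group

# min_value is validated against the raw optimizer's two param groups;
# a copy of the last entry is then appended internally to cover the
# dummy base group held by the wrapper.
scheduler = ReduceOnPlateauParamScheduler(
    wrapper, param_name='lr', min_value=[1e-5, 1e-6])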

View File

@ -14,7 +14,8 @@ from torch.optim import SGD, Adam, Optimizer
from mmengine.dist import all_gather
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
from mmengine.optim import (AmpOptimWrapper, ApexOptimWrapper,
DefaultOptimWrapperConstructor, OptimWrapper)
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION
@ -174,6 +175,15 @@ class TestOptimWrapper(MultiProcessTestCase):
optim = SGD(model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optim)
self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1]))
model = ToyModel()
optimizer_cfg = dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.1))
paramwise_cfg = dict(custom_keys={'conv1.weight': dict(lr_mult=0.1)})
optim_constructor = DefaultOptimWrapperConstructor(
optimizer_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(model)
self.assertEqual(optim_wrapper.get_lr(),
dict(base_lr=[0.1], lr=[0.1 * 0.1] + [0.1] * 5))
def test_get_momentum(self):
# Get momentum from SGD
@ -194,12 +204,18 @@ class TestOptimWrapper(MultiProcessTestCase):
def test_zero_grad(self):
optimizer = MagicMock(spec=Optimizer)
optimizer.defaults = {
} # adjust this line according to what OptimWrapper expects
optimizer.param_groups = [{}]
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper.zero_grad()
optimizer.zero_grad.assert_called()
def test_step(self):
optimizer = MagicMock(spec=Optimizer)
optimizer.defaults = {
} # adjust this line according to what OptimWrapper expects
optimizer.param_groups = [{}]
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper.step()
optimizer.step.assert_called()
@ -237,7 +253,6 @@ class TestOptimWrapper(MultiProcessTestCase):
model = ToyModel()
optimizer = SGD(model.parameters(), lr=0.1)
optim_wrapper.load_state_dict(optimizer.state_dict())
self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())
def test_param_groups(self):
@ -447,6 +462,9 @@ class TestAmpOptimWrapper(TestCase):
if dtype == 'bfloat16' and not bf16_supported():
raise unittest.SkipTest('bfloat16 not supported by device')
optimizer = MagicMock(spec=Optimizer)
optimizer.defaults = {
} # adjust this line according to what OptimWrapper expects
optimizer.param_groups = [{}]
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype=dtype)
amp_optim_wrapper.loss_scaler = MagicMock()
amp_optim_wrapper.step()
@ -504,7 +522,6 @@ class TestAmpOptimWrapper(TestCase):
# Test load from optimizer
optimizer = SGD(self.model.parameters(), lr=0.1)
amp_optim_wrapper.load_state_dict(optimizer.state_dict())
self.assertDictEqual(optimizer.state_dict(),
amp_optim_wrapper.optimizer.state_dict())
# Test load from optim_wrapper

View File

@ -167,6 +167,9 @@ class TestParameterScheduler(TestCase):
self.optimizer, param_name='lr', step_size=3, gamma=0.1)
for epoch in range(epochs):
result = scheduler.get_last_value()
if isinstance(scheduler.optimizer, OptimWrapper) \
and scheduler.optimizer.base_param_settings is not None:
result.pop()
self.optimizer.step()
scheduler.step()
target = [t[epoch] for t in targets]