[Enhance] Learning rate in log can show the base learning rate of optimizer (#1019)
parent eb2dc671e9
commit 94e7a3bb57
@@ -133,7 +133,7 @@ class AmpOptimWrapper(OptimWrapper):
             :attr:`optimizer`.
         """
         # save state_dict of loss_scaler
-        state_dict = self.optimizer.state_dict()
+        state_dict = super().state_dict()
         state_dict['loss_scaler'] = self.loss_scaler.state_dict()
         return state_dict
@@ -151,6 +151,11 @@ class AmpOptimWrapper(OptimWrapper):
         """
         if 'loss_scaler' in state_dict:
             self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))
+
+        if 'base_param_settings' in state_dict:
+            self.base_param_settings = state_dict.pop('base_param_settings')
+
+        # load state_dict of optimizer
         self.optimizer.load_state_dict(state_dict)
 
     @contextmanager
@@ -161,6 +161,20 @@ class OptimWrapper:
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
+        # The following code initializes `base_param_settings`.
+        # `base_param_settings` stores a dummy parameter group that is not
+        # updated by the optimizer; it is only used to track the base
+        # learning rate of the optimizer. If the optimizer has multiple
+        # parameter groups, this dummy group will not be scaled by the
+        # loss factor.
+        if len(optimizer.param_groups) > 1:
+            self.base_param_settings = {
+                'params': torch.tensor([0.0], dtype=torch.float)
+            }
+            self.base_param_settings.update(**self.optimizer.defaults)
+        else:
+            self.base_param_settings = None  # type: ignore
+
     def update_params(self,
                       loss: torch.Tensor,
                       step_kwargs: Optional[Dict] = None,
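For context, a minimal usage sketch of the new `base_param_settings` attribute (illustration only, not part of the diff; it assumes the patched `OptimWrapper` above and a stock torch `SGD`):

import torch
from torch.optim import SGD
from mmengine.optim import OptimWrapper

w1 = torch.nn.Parameter(torch.zeros(1))
w2 = torch.nn.Parameter(torch.zeros(1))
# Two parameter groups: the first uses a 10x smaller lr than the default.
optimizer = SGD([{'params': [w1], 'lr': 0.01}, {'params': [w2]}],
                lr=0.1, momentum=0.9)
wrapper = OptimWrapper(optimizer)

# With more than one param group the wrapper keeps a dummy group that
# copies the optimizer defaults, including the base lr of 0.1.
print(wrapper.base_param_settings['lr'])  # 0.1
print(wrapper.get_lr())  # {'base_lr': [0.1], 'lr': [0.01, 0.1]}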
@@ -251,7 +265,10 @@ class OptimWrapper:
         Returns:
             dict: The state dictionary of :attr:`optimizer`.
         """
-        return self.optimizer.state_dict()
+        state_dict = self.optimizer.state_dict()
+        if self.base_param_settings is not None:
+            state_dict['base_param_settings'] = self.base_param_settings
+        return state_dict
 
     def load_state_dict(self, state_dict: dict) -> None:
         """A wrapper of ``Optimizer.load_state_dict``. load the state dict of
@@ -265,6 +282,12 @@ class OptimWrapper:
         Args:
             state_dict (dict): The state dictionary of :attr:`optimizer`.
         """
+        base_param_settings = state_dict.pop('base_param_settings', None)
+
+        if base_param_settings is not None:
+            self.base_param_settings = base_param_settings
+
+        # load state_dict of optimizer
         self.optimizer.load_state_dict(state_dict)
 
     @property
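A round-trip sketch of the two methods above (illustration only, not part of the diff; it assumes the patched `state_dict`/`load_state_dict` and an `SGD` with two parameter groups):

import torch
from torch.optim import SGD
from mmengine.optim import OptimWrapper


def make_wrapper(lr):
    # Two single-parameter groups so that `base_param_settings` is created.
    w1 = torch.nn.Parameter(torch.zeros(1))
    w2 = torch.nn.Parameter(torch.ones(1))
    return OptimWrapper(SGD([{'params': [w1]}, {'params': [w2]}], lr=lr))


src = make_wrapper(lr=0.1)
state = src.state_dict()
assert 'base_param_settings' in state  # rides along with the optimizer state

dst = make_wrapper(lr=0.5)
dst.load_state_dict(state)
# `base_param_settings` is restored and popped before the raw optimizer
# ever sees the state dict.
assert dst.base_param_settings['lr'] == 0.1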
@@ -276,7 +299,10 @@ class OptimWrapper:
         Returns:
             dict: the ``param_groups`` of :attr:`optimizer`.
         """
-        return self.optimizer.param_groups
+        if self.base_param_settings is not None:
+            return self.optimizer.param_groups + [self.base_param_settings]
+        else:
+            return self.optimizer.param_groups
 
     @property
     def defaults(self) -> dict:
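One consequence of the property above, sketched under the same assumptions: the wrapper now exposes one more group than the raw optimizer, which is also why the lr/momentum getters below iterate `self.optimizer.param_groups` instead of `self.param_groups`.

import torch
from torch.optim import SGD
from mmengine.optim import OptimWrapper

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
wrapper = OptimWrapper(SGD([{'params': [p]} for p in params], lr=0.1))

assert len(wrapper.optimizer.param_groups) == 2
assert len(wrapper.param_groups) == 3  # includes the dummy base group
assert wrapper.param_groups[-1] is wrapper.base_param_settings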
@@ -295,10 +321,16 @@ class OptimWrapper:
         Provide unified interface to get learning rate of optimizer.
 
         Returns:
-            Dict[str, List[float]]: Learning rate of the optimizer.
+            Dict[str, List[float]]: Learning rates of the param_groups
+            and, if available, the base learning rate of the optimizer.
         """
-        lr = [group['lr'] for group in self.param_groups]
-        return dict(lr=lr)
+        res = {}
+        if self.base_param_settings is not None:
+            res['base_lr'] = [self.base_param_settings['lr']]
+
+        res['lr'] = [group['lr'] for group in self.optimizer.param_groups]
+
+        return res
 
     def get_momentum(self) -> Dict[str, List[float]]:
         """Get the momentum of the optimizer.
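Consumers of `get_lr` now receive either one or two keys. A small sketch of how logging code could pick a single representative value (the `representative_lr` helper is hypothetical, not part of mmengine):

from typing import Dict, List


def representative_lr(lrs: Dict[str, List[float]]) -> float:
    """Prefer the recorded base lr; otherwise fall back to the first group."""
    return lrs.get('base_lr', lrs['lr'])[0]


# A single-group wrapper yields {'lr': [0.1]}; a multi-group wrapper yields
# {'base_lr': [0.1], 'lr': [0.01, 0.1]}, and both map to 0.1 here.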
@@ -309,7 +341,7 @@ class OptimWrapper:
             Dict[str, List[float]]: Momentum of the optimizer.
         """
         momentum = []
-        for group in self.param_groups:
+        for group in self.optimizer.param_groups:
             # Get momentum of SGD.
             if 'momentum' in group.keys():
                 momentum.append(group['momentum'])
@@ -115,7 +115,10 @@ class OptimWrapperDict(OptimWrapper):
         """
         lr_dict = dict()
         for name, optim_wrapper in self.optim_wrappers.items():
-            lr_dict[f'{name}.lr'] = optim_wrapper.get_lr()['lr']
+            inner_lr_dict = optim_wrapper.get_lr()
+            if 'base_lr' in inner_lr_dict:
+                lr_dict[f'{name}.base_lr'] = inner_lr_dict['base_lr']
+            lr_dict[f'{name}.lr'] = inner_lr_dict['lr']
         return lr_dict
 
     def get_momentum(self) -> Dict[str, List[float]]:
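For `OptimWrapperDict`, the per-wrapper prefixes carry over; a sketch with made-up names `gen`/`disc`, assuming the patched `get_lr` of both classes:

import torch
from torch.optim import SGD
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen_params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
disc_param = torch.nn.Parameter(torch.zeros(1))

gen = OptimWrapper(SGD([{'params': [p]} for p in gen_params], lr=0.1))
disc = OptimWrapper(SGD([disc_param], lr=0.2))

wrappers = OptimWrapperDict(gen=gen, disc=disc)
print(wrappers.get_lr())
# {'gen.base_lr': [0.1], 'gen.lr': [0.1, 0.1], 'disc.lr': [0.2]}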
@@ -1409,11 +1409,28 @@ class ReduceOnPlateauParamScheduler(_ParamScheduler):
             raise ValueError('Factor should be < 1.0.')
         self.factor = factor
 
+        # This code snippet handles compatibility with the optimizer
+        # wrapper. The optimizer wrapper may carry an additional parameter
+        # group that records the base learning rate, which is not affected
+        # by the paramwise_cfg, so `min_value` has to be checked against
+        # the raw optimizer's param_groups.
+        if isinstance(optimizer, OptimWrapper):
+            raw_optimizer = optimizer.optimizer
+        else:
+            raw_optimizer = optimizer
+
         if isinstance(min_value, (list, tuple)):
-            if len(min_value) != len(optimizer.param_groups):
+            if len(min_value) != len(raw_optimizer.param_groups):
                 raise ValueError('expected {} min_lrs, got {}'.format(
-                    len(optimizer.param_groups), len(min_value)))
+                    len(raw_optimizer.param_groups), len(min_value)))
             self.min_values = list(min_value)
+            # Treat the `min_value` of the last param_group as the base
+            # setting; it is only appended when the optimizer is an
+            # OptimWrapper that tracks `base_param_settings`.
+            if isinstance(optimizer, OptimWrapper) and \
+                    optimizer.base_param_settings is not None:
+                self.min_values.append(self.min_values[-1])
+
         else:
             self.min_values = [min_value] * len(  # type: ignore
                 optimizer.param_groups)
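The `min_value` handling above can be summarized as a standalone sketch (the `resolve_min_values` helper is hypothetical, not library code; it assumes the patched `OptimWrapper`):

from typing import List, Sequence, Union

from torch.optim import Optimizer

from mmengine.optim import OptimWrapper


def resolve_min_values(optimizer: Union[Optimizer, OptimWrapper],
                       min_value: Union[float, Sequence[float]]) -> List[float]:
    """Mirror the compatibility logic of ReduceOnPlateauParamScheduler."""
    if isinstance(optimizer, OptimWrapper):
        raw = optimizer.optimizer
    else:
        raw = optimizer
    if isinstance(min_value, (list, tuple)):
        if len(min_value) != len(raw.param_groups):
            raise ValueError(f'expected {len(raw.param_groups)} min_lrs, '
                             f'got {len(min_value)}')
        min_values = list(min_value)
        # The wrapper's extra base group reuses the last per-group value.
        if isinstance(optimizer, OptimWrapper) and \
                optimizer.base_param_settings is not None:
            min_values.append(min_values[-1])
        return min_values
    # A scalar is broadcast over `optimizer.param_groups`, which already
    # includes the wrapper's base group when present.
    return [min_value] * len(optimizer.param_groups)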
@@ -14,7 +14,8 @@ from torch.optim import SGD, Adam, Optimizer
 
 from mmengine.dist import all_gather
 from mmengine.logging import MessageHub, MMLogger
-from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
+from mmengine.optim import (AmpOptimWrapper, ApexOptimWrapper,
+                            DefaultOptimWrapperConstructor, OptimWrapper)
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -174,6 +175,15 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim = SGD(model.parameters(), lr=0.1)
         optim_wrapper = OptimWrapper(optim)
         self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1]))
+        model = ToyModel()
+        optimizer_cfg = dict(
+            type='OptimWrapper', optimizer=dict(type='SGD', lr=0.1))
+        paramwise_cfg = dict(custom_keys={'conv1.weight': dict(lr_mult=0.1)})
+        optim_constructor = DefaultOptimWrapperConstructor(
+            optimizer_cfg, paramwise_cfg)
+        optim_wrapper = optim_constructor(model)
+        self.assertEqual(optim_wrapper.get_lr(),
+                         dict(base_lr=[0.1], lr=[0.1 * 0.1] + [0.1] * 5))
 
     def test_get_momentum(self):
         # Get momentum from SGD
@@ -194,12 +204,18 @@ class TestOptimWrapper(MultiProcessTestCase):
 
     def test_zero_grad(self):
         optimizer = MagicMock(spec=Optimizer)
+        optimizer.defaults = {
+        }  # the wrapper now inspects `defaults`/`param_groups` at init time
+        optimizer.param_groups = [{}]
         optim_wrapper = OptimWrapper(optimizer)
         optim_wrapper.zero_grad()
         optimizer.zero_grad.assert_called()
 
     def test_step(self):
         optimizer = MagicMock(spec=Optimizer)
+        optimizer.defaults = {
+        }  # the wrapper now inspects `defaults`/`param_groups` at init time
+        optimizer.param_groups = [{}]
         optim_wrapper = OptimWrapper(optimizer)
         optim_wrapper.step()
         optimizer.step.assert_called()
@@ -237,7 +253,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         model = ToyModel()
         optimizer = SGD(model.parameters(), lr=0.1)
         optim_wrapper.load_state_dict(optimizer.state_dict())
-
         self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())
 
     def test_param_groups(self):
@@ -447,6 +462,9 @@ class TestAmpOptimWrapper(TestCase):
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         optimizer = MagicMock(spec=Optimizer)
+        optimizer.defaults = {
+        }  # the wrapper now inspects `defaults`/`param_groups` at init time
+        optimizer.param_groups = [{}]
         amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype=dtype)
         amp_optim_wrapper.loss_scaler = MagicMock()
         amp_optim_wrapper.step()
@@ -504,7 +522,6 @@ class TestAmpOptimWrapper(TestCase):
         # Test load from optimizer
         optimizer = SGD(self.model.parameters(), lr=0.1)
         amp_optim_wrapper.load_state_dict(optimizer.state_dict())
-
         self.assertDictEqual(optimizer.state_dict(),
                              amp_optim_wrapper.optimizer.state_dict())
         # Test load from optim_wrapper
@@ -167,6 +167,9 @@ class TestParameterScheduler(TestCase):
             self.optimizer, param_name='lr', step_size=3, gamma=0.1)
         for epoch in range(epochs):
             result = scheduler.get_last_value()
+            if isinstance(scheduler.optimizer, OptimWrapper) \
+                    and scheduler.optimizer.base_param_settings is not None:
+                result.pop()
             self.optimizer.step()
             scheduler.step()
             target = [t[epoch] for t in targets]