Mirror of https://github.com/open-mmlab/mmengine.git
[Enhance] Learning rate in log can show the base learning rate of optimizer (#1019)
This commit is contained in:
    parent eb2dc671e9
    commit 94e7a3bb57
@@ -133,7 +133,7 @@ class AmpOptimWrapper(OptimWrapper):
             :attr:`optimizer`.
         """
         # save state_dict of loss_scaler
-        state_dict = self.optimizer.state_dict()
+        state_dict = super().state_dict()
         state_dict['loss_scaler'] = self.loss_scaler.state_dict()
         return state_dict
 
@@ -151,6 +151,11 @@ class AmpOptimWrapper(OptimWrapper):
         """
         if 'loss_scaler' in state_dict:
             self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))
+
+        if 'base_param_settings' in state_dict:
+            self.base_param_settings = state_dict.pop('base_param_settings')
+
+        # load state_dict of optimizer
         self.optimizer.load_state_dict(state_dict)
 
     @contextmanager
@@ -161,6 +161,20 @@ class OptimWrapper:
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
+        # The following code initializes `base_param_settings`.
+        # `base_param_settings` is a dummy parameter group used only to
+        # track the base learning rate of the optimizer; its parameters are
+        # never updated by the optimizer and it is not scaled by the loss
+        # factor. It is only created when the optimizer has multiple
+        # parameter groups.
+        if len(optimizer.param_groups) > 1:
+            self.base_param_settings = {
+                'params': torch.tensor([0.0], dtype=torch.float)
+            }
+            self.base_param_settings.update(**self.optimizer.defaults)
+        else:
+            self.base_param_settings = None  # type: ignore
+
     def update_params(self,
                       loss: torch.Tensor,
                       step_kwargs: Optional[Dict] = None,
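For readers following the change: the block above creates a dummy parameter group whose only job is to remember the optimizer's base learning rate. A minimal sketch of what it ends up holding, using an illustrative two-group SGD setup that is not part of the patch:

    import torch
    from torch.optim import SGD
    from mmengine.optim import OptimWrapper

    # Two parameter groups with different learning rates (illustrative only).
    w1 = torch.nn.Parameter(torch.zeros(2))
    w2 = torch.nn.Parameter(torch.zeros(2))
    optim = SGD([{'params': [w1], 'lr': 0.01}, {'params': [w2]}], lr=0.1)

    wrapper = OptimWrapper(optim)
    # With more than one param group, the wrapper builds a dummy group seeded
    # from `optim.defaults`, e.g. {'params': tensor([0.]), 'lr': 0.1, ...}.
    print(wrapper.base_param_settings['lr'])  # 0.1, the base learning rate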
@@ -251,7 +265,10 @@ class OptimWrapper:
         Returns:
             dict: The state dictionary of :attr:`optimizer`.
         """
-        return self.optimizer.state_dict()
+        state_dict = self.optimizer.state_dict()
+        if self.base_param_settings is not None:
+            state_dict['base_param_settings'] = self.base_param_settings
+        return state_dict
 
     def load_state_dict(self, state_dict: dict) -> None:
         """A wrapper of ``Optimizer.load_state_dict``. load the state dict of
@@ -265,6 +282,12 @@ class OptimWrapper:
         Args:
             state_dict (dict): The state dictionary of :attr:`optimizer`.
         """
+        base_param_settings = state_dict.pop('base_param_settings', None)
+
+        if base_param_settings is not None:
+            self.base_param_settings = base_param_settings
+
+        # load state_dict of optimizer
         self.optimizer.load_state_dict(state_dict)
 
     @property
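Together with the `state_dict` change above, this makes `base_param_settings` survive a checkpoint round trip: it is stored alongside the optimizer state on save and popped back out before the raw optimizer loads the rest. A short sketch continuing the toy two-group example from earlier (illustrative only):

    state = wrapper.state_dict()
    assert 'base_param_settings' in state  # saved next to the optimizer state

    # The key is popped before Optimizer.load_state_dict sees the dict, so the
    # raw optimizer still receives only the keys it expects.
    new_wrapper = OptimWrapper(
        SGD([{'params': [w1], 'lr': 0.01}, {'params': [w2]}], lr=0.1))
    new_wrapper.load_state_dict(state)
    assert new_wrapper.base_param_settings['lr'] == 0.1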
@@ -276,6 +299,9 @@ class OptimWrapper:
         Returns:
             dict: the ``param_groups`` of :attr:`optimizer`.
         """
-        return self.optimizer.param_groups
+        if self.base_param_settings is not None:
+            return self.optimizer.param_groups + [self.base_param_settings]
+        else:
+            return self.optimizer.param_groups
 
     @property
@@ -295,10 +321,16 @@ class OptimWrapper:
         Provide unified interface to get learning rate of optimizer.
 
         Returns:
-            Dict[str, List[float]]: Learning rate of the optimizer.
+            Dict[str, List[float]]:
+            param_groups learning rate of the optimizer.
         """
-        lr = [group['lr'] for group in self.param_groups]
-        return dict(lr=lr)
+        res = {}
+        if self.base_param_settings is not None:
+            res['base_lr'] = [self.base_param_settings['lr']]
+
+        res['lr'] = [group['lr'] for group in self.optimizer.param_groups]
+
+        return res
 
     def get_momentum(self) -> Dict[str, List[float]]:
         """Get the momentum of the optimizer.
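The result of `get_lr()` therefore gains a `base_lr` entry whenever the dummy group exists, while `lr` is read straight from the raw optimizer so the dummy group is not reported twice. Continuing the toy example above, the expected output looks roughly like:

    print(wrapper.get_lr())
    # -> {'base_lr': [0.1], 'lr': [0.01, 0.1]}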
@@ -309,7 +341,7 @@ class OptimWrapper:
             Dict[str, List[float]]: Momentum of the optimizer.
         """
         momentum = []
-        for group in self.param_groups:
+        for group in self.optimizer.param_groups:
             # Get momentum of SGD.
             if 'momentum' in group.keys():
                 momentum.append(group['momentum'])
@@ -115,7 +115,10 @@ class OptimWrapperDict(OptimWrapper):
         """
         lr_dict = dict()
         for name, optim_wrapper in self.optim_wrappers.items():
-            lr_dict[f'{name}.lr'] = optim_wrapper.get_lr()['lr']
+            inner_lr_dict = optim_wrapper.get_lr()
+            if 'base_lr' in inner_lr_dict:
+                lr_dict[f'{name}.base_lr'] = inner_lr_dict['base_lr']
+            lr_dict[f'{name}.lr'] = inner_lr_dict['lr']
         return lr_dict
 
     def get_momentum(self) -> Dict[str, List[float]]:
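`OptimWrapperDict` forwards the same information with the wrapper name as a prefix, so the logged keys stay distinguishable. A hedged sketch with made-up wrapper names ('generator', 'discriminator') that are not from the patch:

    import torch
    from torch.optim import SGD
    from mmengine.optim import OptimWrapper, OptimWrapperDict

    # The generator has two param groups (so it carries a base lr); the
    # discriminator has a single group (so it does not).
    g1 = torch.nn.Parameter(torch.zeros(2))
    g2 = torch.nn.Parameter(torch.zeros(2))
    gen_wrapper = OptimWrapper(
        SGD([{'params': [g1], 'lr': 0.01}, {'params': [g2]}], lr=0.1))
    disc_wrapper = OptimWrapper(
        SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.001))

    optim_dict = OptimWrapperDict(
        generator=gen_wrapper, discriminator=disc_wrapper)
    print(optim_dict.get_lr())
    # -> {'generator.base_lr': [0.1], 'generator.lr': [0.01, 0.1],
    #     'discriminator.lr': [0.001]}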
@@ -1409,11 +1409,28 @@ class ReduceOnPlateauParamScheduler(_ParamScheduler):
             raise ValueError('Factor should be < 1.0.')
         self.factor = factor
 
+        # This code handles compatibility with the optimizer wrapper. The
+        # wrapper may carry an extra dummy parameter group that records the
+        # base learning rate (lr) and is not affected by the paramwise_cfg,
+        # so the raw optimizer is retrieved here to count only the real
+        # parameter groups.
+        if isinstance(optimizer, OptimWrapper):
+            raw_optimizer = optimizer.optimizer
+        else:
+            raw_optimizer = optimizer
+
         if isinstance(min_value, (list, tuple)):
-            if len(min_value) != len(optimizer.param_groups):
+            if len(min_value) != len(raw_optimizer.param_groups):
                 raise ValueError('expected {} min_lrs, got {}'.format(
-                    len(optimizer.param_groups), len(min_value)))
+                    len(raw_optimizer.param_groups), len(min_value)))
             self.min_values = list(min_value)
+            # Treat the `min_value` of the last param group as the base
+            # setting; it is only appended when the optimizer is an
+            # OptimWrapper that tracks `base_param_settings`.
+            if isinstance(optimizer, OptimWrapper) and \
+                    optimizer.base_param_settings is not None:
+                self.min_values.append(self.min_values[-1])
+
         else:
             self.min_values = [min_value] * len(  # type: ignore
                 optimizer.param_groups)
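The extra entry appended to `min_values` mirrors the dummy group that `OptimWrapper.param_groups` now exposes, keeping the scheduler's per-group lists the same length. A sketch of the intended invariant, reusing the toy wrapper from above; the import path and default scheduler arguments are assumed from recent mmengine releases, so treat this as an illustration rather than a guaranteed API:

    from mmengine.optim import ReduceOnPlateauParamScheduler

    # Hypothetical per-group floors for the two real param groups.
    scheduler = ReduceOnPlateauParamScheduler(
        wrapper, param_name='lr', min_value=[1e-5, 1e-6])
    # 2 real groups + 1 dummy base group, with the last floor duplicated
    # for the dummy group.
    assert len(scheduler.min_values) == len(wrapper.param_groups) == 3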
@@ -14,7 +14,8 @@ from torch.optim import SGD, Adam, Optimizer
 
 from mmengine.dist import all_gather
 from mmengine.logging import MessageHub, MMLogger
-from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
+from mmengine.optim import (AmpOptimWrapper, ApexOptimWrapper,
+                            DefaultOptimWrapperConstructor, OptimWrapper)
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -174,6 +175,15 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim = SGD(model.parameters(), lr=0.1)
         optim_wrapper = OptimWrapper(optim)
         self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1]))
+        model = ToyModel()
+        optimizer_cfg = dict(
+            type='OptimWrapper', optimizer=dict(type='SGD', lr=0.1))
+        paramwise_cfg = dict(custom_keys={'conv1.weight': dict(lr_mult=0.1)})
+        optim_constructor = DefaultOptimWrapperConstructor(
+            optimizer_cfg, paramwise_cfg)
+        optim_wrapper = optim_constructor(model)
+        self.assertEqual(optim_wrapper.get_lr(),
+                         dict(base_lr=[0.1], lr=[0.1 * 0.1] + [0.1] * 5))
 
     def test_get_momentum(self):
         # Get momentum from SGD
@@ -194,12 +204,18 @@ class TestOptimWrapper(MultiProcessTestCase):
 
     def test_zero_grad(self):
         optimizer = MagicMock(spec=Optimizer)
+        optimizer.defaults = {
+        }  # adjust this line according to what OptimWrapper expects
+        optimizer.param_groups = [{}]
         optim_wrapper = OptimWrapper(optimizer)
         optim_wrapper.zero_grad()
         optimizer.zero_grad.assert_called()
 
     def test_step(self):
         optimizer = MagicMock(spec=Optimizer)
+        optimizer.defaults = {
+        }  # adjust this line according to what OptimWrapper expects
+        optimizer.param_groups = [{}]
         optim_wrapper = OptimWrapper(optimizer)
         optim_wrapper.step()
         optimizer.step.assert_called()
@@ -237,7 +253,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         model = ToyModel()
         optimizer = SGD(model.parameters(), lr=0.1)
         optim_wrapper.load_state_dict(optimizer.state_dict())
-
         self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())
 
     def test_param_groups(self):
@@ -447,6 +462,9 @@ class TestAmpOptimWrapper(TestCase):
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         optimizer = MagicMock(spec=Optimizer)
+        optimizer.defaults = {
+        }  # adjust this line according to what OptimWrapper expects
+        optimizer.param_groups = [{}]
         amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype=dtype)
         amp_optim_wrapper.loss_scaler = MagicMock()
         amp_optim_wrapper.step()
@@ -504,7 +522,6 @@ class TestAmpOptimWrapper(TestCase):
         # Test load from optimizer
         optimizer = SGD(self.model.parameters(), lr=0.1)
         amp_optim_wrapper.load_state_dict(optimizer.state_dict())
-
         self.assertDictEqual(optimizer.state_dict(),
                              amp_optim_wrapper.optimizer.state_dict())
         # Test load from optim_wrapper
@@ -167,6 +167,9 @@ class TestParameterScheduler(TestCase):
             self.optimizer, param_name='lr', step_size=3, gamma=0.1)
         for epoch in range(epochs):
             result = scheduler.get_last_value()
+            if isinstance(scheduler.optimizer, OptimWrapper) \
+                    and scheduler.optimizer.base_param_settings is not None:
+                result.pop()
             self.optimizer.step()
             scheduler.step()
             target = [t[epoch] for t in targets]
|
Loading…
x
Reference in New Issue
Block a user