mirror of https://github.com/open-mmlab/mmengine.git
[Enhance] Add documents for clip_grad, and support clipping gradients by value (#513)

* [Enhance] add documents for clip_grad, and support clip grad by value
* refine docstring
* fix as comment
* Fix as comment
* minor refine
* minor refine
* remove error comment for clip grad
* refine docstring
parent 4111cfb511
commit 6073d9ebd8
@@ -132,6 +132,18 @@ for idx, (input, target) in enumerate(zip(inputs, targets)):
     optim_wrapper.zero_grad()
 ```
 
+We can also configure a gradient clipping strategy for the optimizer wrapper:
+
+```python
+# Clip gradients based on torch.nn.utils.clip_grad_norm_
+optim_wrapper = AmpOptimWrapper(
+    optimizer=optimizer, clip_grad=dict(max_norm=1))
+
+# Clip gradients based on torch.nn.utils.clip_grad_value_
+optim_wrapper = AmpOptimWrapper(
+    optimizer=optimizer, clip_grad=dict(clip_value=0.2))
+```
+
 ### Get the learning rate/momentum
 
 The optimizer wrapper provides the `get_lr` and `get_momentum` interfaces for getting the learning rate of one of the optimizer's parameter groups
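For reference, a minimal CPU-only sketch of how the new `clip_grad` option plays out in a full update step. The toy model, data, and SGD optimizer are assumptions for illustration; `OptimWrapper`, `clip_grad`, and `update_params` come from mmengine itself:

```python
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper

# Toy model and optimizer, purely for illustration.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Omitting the `type` key means clipping by norm, i.e. the wrapper ends up
# calling torch.nn.utils.clip_grad_norm_(params, max_norm=1).
optim_wrapper = OptimWrapper(optimizer, clip_grad=dict(max_norm=1))

loss = model(torch.randn(2, 4)).sum()
# update_params() runs backward(), clips the gradients, steps the optimizer
# and zeroes the gradients in one call.
optim_wrapper.update_params(loss)
```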
@@ -5,7 +5,6 @@ from typing import Dict, List, Optional
 
 import torch
 import torch.nn as nn
-from torch.nn.utils import clip_grad
 from torch.optim import Optimizer
 
 from mmengine.logging import MessageHub, print_log
@@ -32,7 +31,27 @@ class OptimWrapper:
             gradients. The parameters will be updated per
             ``accumulative_counts``.
         clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
-            the arguments of ``torch.nn.utils.clip_grad``.
+            the arguments of :func:`torch.nn.utils.clip_grad_norm_` or
+            :func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a
+            dict, and the keys could be set as follows:
+
+            If the key ``type`` is not set, or ``type`` is "norm",
+            the accepted keys are as follows:
+
+            - max_norm (float or int): Max norm of the gradients.
+            - norm_type (float or int): Type of the used p-norm. Can be
+              ``'inf'`` for infinity norm.
+            - error_if_nonfinite (bool): If True, an error is thrown if
+              the total norm of the gradients from :attr:`parameters` is
+              ``nan``, ``inf``, or ``-inf``. Default: False (will switch
+              to True in the future)
+
+            If the key ``type`` is set to "value", the accepted keys are as
+            follows:
+
+            - clip_value (float or int): maximum allowed value of the
+              gradients. The gradients are clipped in the range
+              ``(-clip_value, +clip_value)``.
 
     Note:
         If ``accumulative_counts`` is larger than 1, perform
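As a sanity check on the two key sets documented above, here is a small sketch of the underlying torch calls they map to. The toy model is an assumption; the argument names follow the docstring:

```python
import torch
import torch.nn as nn

# Build some gradients to clip.
model = nn.Linear(4, 1)
model(torch.randn(2, 4)).sum().backward()
params = [p for p in model.parameters() if p.grad is not None]

# clip_grad=dict(max_norm=0.2, norm_type=2) -> clip by norm;
# clip_grad_norm_ returns the total norm of the gradients.
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=0.2, norm_type=2)

# clip_grad=dict(type='value', clip_value=0.2) -> clip by value;
# clip_grad_value_ clips in place and returns None.
torch.nn.utils.clip_grad_value_(params, clip_value=0.2)
```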
@@ -49,11 +68,18 @@ class OptimWrapper:
         ``_inner_count += 1`` is automatically performed.
 
     Examples:
-        >>> # Config sample of OptimWrapper.
+        >>> # Config sample of OptimWrapper and enable clipping gradient by
+        >>> # norm.
         >>> optim_wrapper_cfg = dict(
         >>>     type='OptimWrapper',
         >>>     _accumulative_counts=1,
         >>>     clip_grad=dict(max_norm=0.2))
+        >>> # Config sample of OptimWrapper and enable clipping gradient by
+        >>> # value.
+        >>> optim_wrapper_cfg = dict(
+        >>>     type='OptimWrapper',
+        >>>     _accumulative_counts=1,
+        >>>     clip_grad=dict(type='value', clip_value=0.2))
         >>> # Use OptimWrapper to update model.
         >>> import torch.nn as nn
         >>> import torch
@@ -105,7 +131,22 @@ class OptimWrapper:
             # clip_grad_kwargs should not be non-empty dict.
             assert isinstance(clip_grad, dict) and clip_grad, (
                 'If `clip_grad` is not None, it should be a `dict` '
-                'which is the arguments of `torch.nn.utils.clip_grad`')
+                'which is the arguments of `torch.nn.utils.clip_grad_norm_` '
+                'or clip_grad_value_`.')
+            clip_type = clip_grad.pop('type', 'norm')
+            if clip_type == 'norm':
+                self.clip_func = torch.nn.utils.clip_grad_norm_
+                self.grad_name = 'grad_norm'
+            elif clip_type == 'value':
+                self.clip_func = torch.nn.utils.clip_grad_value_
+                self.grad_name = 'grad_value'
+            else:
+                raise ValueError('type of clip_grad should be "norm" or '
+                                 f'"value" but got {clip_type}')
+            assert clip_grad, ('`clip_grad` should contain other arguments '
+                               'besides `type`. The arguments should match '
+                               'with the `torch.nn.utils.clip_grad_norm_` or '
+                               'clip_grad_value_`')
         self.clip_grad_kwargs = clip_grad
         # Used to update `grad_norm` log message.
         self.message_hub = MessageHub.get_current_instance()
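The dispatch above amounts to popping `type` from the dict and forwarding the remaining keys to the chosen torch function. A standalone sketch of that idea follows; `resolve_clip_func` is a hypothetical helper, not part of mmengine:

```python
import torch

def resolve_clip_func(clip_grad: dict):
    # Hypothetical helper mirroring the dispatch in OptimWrapper.__init__.
    # Note that dict.pop mutates `clip_grad`: only the remaining keys are
    # forwarded to the selected torch function.
    clip_type = clip_grad.pop('type', 'norm')
    if clip_type == 'norm':
        return torch.nn.utils.clip_grad_norm_, 'grad_norm', clip_grad
    elif clip_type == 'value':
        return torch.nn.utils.clip_grad_value_, 'grad_value', clip_grad
    raise ValueError('type of clip_grad should be "norm" or "value" '
                     f'but got {clip_type}')

clip_func, grad_name, kwargs = resolve_clip_func(
    dict(type='value', clip_value=0.2))
# clip_func is torch.nn.utils.clip_grad_value_, kwargs == {'clip_value': 0.2}
```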
@@ -305,9 +346,11 @@ class OptimWrapper:
         params = list(
             filter(lambda p: p.requires_grad and p.grad is not None, params))
         if len(params) > 0:
-            grad_norm = clip_grad.clip_grad_norm_(params,
-                                                  **self.clip_grad_kwargs)
-            self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
+            grad = self.clip_func(params, **self.clip_grad_kwargs)
+            # `torch.nn.utils.clip_grad_value_` will return None.
+            if grad is not None:
+                self.message_hub.update_scalar(f'train/{self.grad_name}',
+                                               float(grad))
 
     def initialize_count_status(self, model: nn.Module, init_counts: int,
                                 max_counts: int) -> None:
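A hedged sketch of how the logged scalar can be read back after an update; it assumes `MessageHub.get_current_instance()` returns the same instance the wrapper logged to (i.e. no runner has created another one):

```python
import torch
import torch.nn as nn
from mmengine.logging import MessageHub
from mmengine.optim import OptimWrapper

model = nn.Linear(4, 1)
optim_wrapper = OptimWrapper(
    torch.optim.SGD(model.parameters(), lr=0.1),
    clip_grad=dict(max_norm=0.2))
optim_wrapper.update_params(model(torch.randn(2, 4)).sum())

# Norm clipping returns the total norm, so 'train/grad_norm' gets logged;
# value clipping returns None, so no 'train/grad_value' scalar would appear.
message_hub = MessageHub.get_current_instance()
print(message_hub.get_scalar('train/grad_norm').current())
```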
@@ -191,6 +191,7 @@ class TestOptimWrapper(MultiProcessTestCase):
     # in the future).
     @pytest.mark.skipif(True, reason='Solved in the future')
     def test_clip_grads(self):
+        # Test `clip_grad` with `clip_norm_`
         optim_wrapper = OptimWrapper(
             self.optimizer, clip_grad=dict(max_norm=35))
         loss = self.model(torch.Tensor(1, 1, 1, 1))
@@ -198,6 +199,15 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper._clip_grad()
         log_scalars = self.message_hub.log_scalars
         self.assertIn('train/grad_norm', log_scalars)
+        self.message_hub._log_scalars.clear()
+
+        # Test `clip_grad` with `clip_value_`
+        optim_wrapper = OptimWrapper(
+            self.optimizer, clip_grad=dict(type='value', clip_value=0.5))
+        loss = self.model(torch.Tensor(1, 1, 1, 1))
+        loss.backward()
+        optim_wrapper._clip_grad()
+        self.assertNotIn('train/grad_norm', log_scalars)
 
     def test_state_dict(self):
         optim_wrapper = OptimWrapper(self.optimizer)
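A companion test one might add for the new `ValueError` branch in `__init__`; this is a sketch only, not part of this commit:

```python
    def test_clip_grads_invalid_type(self):
        # A hypothetical follow-up test: an unknown `type` key should be
        # rejected when constructing the wrapper.
        with self.assertRaises(ValueError):
            OptimWrapper(
                self.optimizer, clip_grad=dict(type='pseudo', clip_value=0.5))
```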
|
Loading…
x
Reference in New Issue
Block a user