[Enhance] add documents for `clip_grad` , and support clip grad by value. (#513)
* [Enhance] add documents for , and support clip grad by value * refine docstring * fix as comment * Fix as comment * minor refine * minor refine * remove error comment for clip grad * refine docstringpull/634/head
parent
4111cfb511
commit
6073d9ebd8
|
@ -132,6 +132,18 @@ for idx, (input, target) in enumerate(zip(inputs, targets)):
|
|||
optim_wrapper.zero_grad()
|
||||
```
|
||||
|
||||
我们同样可以为优化器封装配置梯度裁减策略:
|
||||
|
||||
```python
|
||||
# 基于 torch.nn.utils.clip_grad_norm_ 对梯度进行裁减
|
||||
optim_wrapper = AmpOptimWrapper(
|
||||
optimizer=optimizer, clip_grad=dict(max_norm=1))
|
||||
|
||||
# 基于 torch.nn.utils.clip_grad_value_ 对梯度进行裁减
|
||||
optim_wrapper = AmpOptimWrapper(
|
||||
optimizer=optimizer, clip_grad=dict(clip_value=0.2))
|
||||
```
|
||||
|
||||
### 获取学习率/动量:
|
||||
|
||||
优化器封装提供了 `get_lr` 和 `get_momentum` 接口用于获取优化器的一个参数组的学习率
|
||||
|
|
|
@ -5,7 +5,6 @@ from typing import Dict, List, Optional
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from mmengine.logging import MessageHub, print_log
|
||||
|
@ -32,7 +31,27 @@ class OptimWrapper:
|
|||
gradients. The parameters will be updated per
|
||||
``accumulative_counts``.
|
||||
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
|
||||
the arguments of ``torch.nn.utils.clip_grad``.
|
||||
the arguments of :func:`torch.nn.utils.clip_grad_norm_` or
|
||||
:func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a
|
||||
dict, and the keys could be set as follows:
|
||||
|
||||
If the key ``type`` is not set, or ``type`` is "norm",
|
||||
the accepted keys are as follows:
|
||||
|
||||
- max_norm (float or int): Max norm of the gradients.
|
||||
- norm_type (float or int): Type of the used p-norm. Can be
|
||||
``'inf'`` for infinity norm.
|
||||
- error_if_nonfinite (bool): If True, an error is thrown if
|
||||
the total norm of the gradients from :attr:`parameters` is
|
||||
``nan``, ``inf``, or ``-inf``. Default: False (will switch
|
||||
to True in the future)
|
||||
|
||||
If the key ``type`` is set to "value", the accepted keys are as
|
||||
follows:
|
||||
|
||||
- clip_value (float or int): maximum allowed value of the
|
||||
gradients. The gradients are clipped in the range
|
||||
``(-clip_value, +clip_value)``.
|
||||
|
||||
Note:
|
||||
If ``accumulative_counts`` is larger than 1, perform
|
||||
|
@ -49,11 +68,18 @@ class OptimWrapper:
|
|||
``_inner_count += 1`` is automatically performed.
|
||||
|
||||
Examples:
|
||||
>>> # Config sample of OptimWrapper.
|
||||
>>> # Config sample of OptimWrapper and enable clipping gradient by
|
||||
>>> # norm.
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
>>> type='OptimWrapper',
|
||||
>>> _accumulative_counts=1,
|
||||
>>> clip_grad=dict(max_norm=0.2))
|
||||
>>> # Config sample of OptimWrapper and enable clipping gradient by
|
||||
>>> # value.
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
>>> type='OptimWrapper',
|
||||
>>> _accumulative_counts=1,
|
||||
>>> clip_grad=dict(type='value', clip_value=0.2))
|
||||
>>> # Use OptimWrapper to update model.
|
||||
>>> import torch.nn as nn
|
||||
>>> import torch
|
||||
|
@ -105,7 +131,22 @@ class OptimWrapper:
|
|||
# clip_grad_kwargs should not be non-empty dict.
|
||||
assert isinstance(clip_grad, dict) and clip_grad, (
|
||||
'If `clip_grad` is not None, it should be a `dict` '
|
||||
'which is the arguments of `torch.nn.utils.clip_grad`')
|
||||
'which is the arguments of `torch.nn.utils.clip_grad_norm_` '
|
||||
'or clip_grad_value_`.')
|
||||
clip_type = clip_grad.pop('type', 'norm')
|
||||
if clip_type == 'norm':
|
||||
self.clip_func = torch.nn.utils.clip_grad_norm_
|
||||
self.grad_name = 'grad_norm'
|
||||
elif clip_type == 'value':
|
||||
self.clip_func = torch.nn.utils.clip_grad_value_
|
||||
self.grad_name = 'grad_value'
|
||||
else:
|
||||
raise ValueError('type of clip_grad should be "norm" or '
|
||||
f'"value" but got {clip_type}')
|
||||
assert clip_grad, ('`clip_grad` should contain other arguments '
|
||||
'besides `type`. The arguments should match '
|
||||
'with the `torch.nn.utils.clip_grad_norm_` or '
|
||||
'clip_grad_value_`')
|
||||
self.clip_grad_kwargs = clip_grad
|
||||
# Used to update `grad_norm` log message.
|
||||
self.message_hub = MessageHub.get_current_instance()
|
||||
|
@ -305,9 +346,11 @@ class OptimWrapper:
|
|||
params = list(
|
||||
filter(lambda p: p.requires_grad and p.grad is not None, params))
|
||||
if len(params) > 0:
|
||||
grad_norm = clip_grad.clip_grad_norm_(params,
|
||||
**self.clip_grad_kwargs)
|
||||
self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
|
||||
grad = self.clip_func(params, **self.clip_grad_kwargs)
|
||||
# `torch.nn.utils.clip_grad_value_` will return None.
|
||||
if grad is not None:
|
||||
self.message_hub.update_scalar(f'train/{self.grad_name}',
|
||||
float(grad))
|
||||
|
||||
def initialize_count_status(self, model: nn.Module, init_counts: int,
|
||||
max_counts: int) -> None:
|
||||
|
|
|
@ -191,6 +191,7 @@ class TestOptimWrapper(MultiProcessTestCase):
|
|||
# in the future).
|
||||
@pytest.mark.skipif(True, reason='Solved in the future')
|
||||
def test_clip_grads(self):
|
||||
# Test `clip_grad` with `clip_norm_`
|
||||
optim_wrapper = OptimWrapper(
|
||||
self.optimizer, clip_grad=dict(max_norm=35))
|
||||
loss = self.model(torch.Tensor(1, 1, 1, 1))
|
||||
|
@ -198,6 +199,15 @@ class TestOptimWrapper(MultiProcessTestCase):
|
|||
optim_wrapper._clip_grad()
|
||||
log_scalars = self.message_hub.log_scalars
|
||||
self.assertIn('train/grad_norm', log_scalars)
|
||||
self.message_hub._log_scalars.clear()
|
||||
|
||||
# Test `clip_grad` with `clip_value_`
|
||||
optim_wrapper = OptimWrapper(
|
||||
self.optimizer, clip_grad=dict(type='value', clip_value=0.5))
|
||||
loss = self.model(torch.Tensor(1, 1, 1, 1))
|
||||
loss.backward()
|
||||
optim_wrapper._clip_grad()
|
||||
self.assertNotIn('train/grad_norm', log_scalars)
|
||||
|
||||
def test_state_dict(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
|
|
Loading…
Reference in New Issue