From 6073d9ebd85db774671278322555a3b175ddd54d Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 18 Oct 2022 18:02:46 +0800
Subject: [PATCH] [Enhance] add documents for `clip_grad`, and support clip
 grad by value. (#513)

* [Enhance] add documents for `clip_grad`, and support clip grad by value

* refine docstring

* fix as comment

* Fix as comment

* minor refine

* minor refine

* remove error comment for clip grad

* refine docstring
---
 docs/zh_cn/tutorials/optim_wrapper.md         | 12 ++++
 mmengine/optim/optimizer/optimizer_wrapper.py | 57 ++++++++++++++++---
 .../test_optimizer/test_optimizer_wrapper.py  | 10 ++++
 3 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/docs/zh_cn/tutorials/optim_wrapper.md b/docs/zh_cn/tutorials/optim_wrapper.md
index 02ff9be7..d5dfe0e5 100644
--- a/docs/zh_cn/tutorials/optim_wrapper.md
+++ b/docs/zh_cn/tutorials/optim_wrapper.md
@@ -132,6 +132,18 @@ for idx, (input, target) in enumerate(zip(inputs, targets)):
     optim_wrapper.zero_grad()
 ```
 
+We can also configure a gradient clipping strategy for the optimizer wrapper:
+
+```python
+# Clip gradients by norm, based on torch.nn.utils.clip_grad_norm_
+optim_wrapper = AmpOptimWrapper(
+    optimizer=optimizer, clip_grad=dict(max_norm=1))
+
+# Clip gradients by value, based on torch.nn.utils.clip_grad_value_
+optim_wrapper = AmpOptimWrapper(
+    optimizer=optimizer, clip_grad=dict(type='value', clip_value=0.2))
+```
+
 ### Get the learning rate/momentum
 
 The optimizer wrapper provides the `get_lr` and `get_momentum` interfaces for getting the learning rate
diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py
index 767f20f1..58dbc051 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper.py
@@ -5,7 +5,6 @@ from typing import Dict, List, Optional
 
 import torch
 import torch.nn as nn
-from torch.nn.utils import clip_grad
 from torch.optim import Optimizer
 
 from mmengine.logging import MessageHub, print_log
@@ -32,7 +31,27 @@ class OptimWrapper:
             gradients. The parameters will be updated per
             ``accumulative_counts``.
         clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
-            the arguments of ``torch.nn.utils.clip_grad``.
+            the arguments of :func:`torch.nn.utils.clip_grad_norm_` or
+            :func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a
+            dict, and the keys could be set as follows:
+
+            If the key ``type`` is not set, or ``type`` is "norm",
+            the accepted keys are as follows:
+
+            - max_norm (float or int): Max norm of the gradients.
+            - norm_type (float or int): Type of the used p-norm. Can be
+              ``'inf'`` for infinity norm.
+            - error_if_nonfinite (bool): If True, an error is thrown if
+              the total norm of the gradients from :attr:`parameters` is
+              ``nan``, ``inf``, or ``-inf``. Default: False (will switch
+              to True in the future)
+
+            If the key ``type`` is set to "value", the accepted keys are as
+            follows:
+
+            - clip_value (float or int): Maximum allowed value of the
+              gradients. The gradients are clipped in the range
+              ``(-clip_value, +clip_value)``.
 
     Note:
         If ``accumulative_counts`` is larger than 1, perform
@@ -49,11 +68,18 @@ class OptimWrapper:
         ``_inner_count += 1`` is automatically performed.
 
     Examples:
-        >>> # Config sample of OptimWrapper.
+        >>> # Config sample of OptimWrapper enabling gradient clipping by
+        >>> # norm.
         >>> optim_wrapper_cfg = dict(
         >>>     type='OptimWrapper',
         >>>     _accumulative_counts=1,
         >>>     clip_grad=dict(max_norm=0.2))
+        >>> # Config sample of OptimWrapper enabling gradient clipping by
+        >>> # value.
+        >>> optim_wrapper_cfg = dict(
+        >>>     type='OptimWrapper',
+        >>>     _accumulative_counts=1,
+        >>>     clip_grad=dict(type='value', clip_value=0.2))
         >>> # Use OptimWrapper to update model.
         >>> import torch.nn as nn
         >>> import torch
@@ -105,7 +131,22 @@ class OptimWrapper:
            # clip_grad_kwargs should be a non-empty dict.
            assert isinstance(clip_grad, dict) and clip_grad, (
                'If `clip_grad` is not None, it should be a `dict` '
-                'which is the arguments of `torch.nn.utils.clip_grad`')
+                'which is the arguments of `torch.nn.utils.clip_grad_norm_` '
+                'or `clip_grad_value_`.')
+            clip_type = clip_grad.pop('type', 'norm')
+            if clip_type == 'norm':
+                self.clip_func = torch.nn.utils.clip_grad_norm_
+                self.grad_name = 'grad_norm'
+            elif clip_type == 'value':
+                self.clip_func = torch.nn.utils.clip_grad_value_
+                self.grad_name = 'grad_value'
+            else:
+                raise ValueError('type of clip_grad should be "norm" or '
+                                 f'"value" but got {clip_type}')
+            assert clip_grad, ('`clip_grad` should contain other arguments '
+                               'besides `type`. The arguments should match '
+                               'those of `torch.nn.utils.clip_grad_norm_` '
+                               'or `clip_grad_value_`.')
        self.clip_grad_kwargs = clip_grad
        # Used to update `grad_norm` log message.
        self.message_hub = MessageHub.get_current_instance()
@@ -305,9 +346,11 @@ class OptimWrapper:
        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
-            grad_norm = clip_grad.clip_grad_norm_(params,
-                                                  **self.clip_grad_kwargs)
-            self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
+            grad = self.clip_func(params, **self.clip_grad_kwargs)
+            # `torch.nn.utils.clip_grad_value_` will return None.
+            if grad is not None:
+                self.message_hub.update_scalar(f'train/{self.grad_name}',
+                                               float(grad))
 
    def initialize_count_status(self, model: nn.Module, init_counts: int,
                                max_counts: int) -> None:
diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
index 2f3a933c..35984ce3 100644
--- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
+++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -191,6 +191,7 @@ class TestOptimWrapper(MultiProcessTestCase):
    # in the future).
    @pytest.mark.skipif(True, reason='Solved in the future')
    def test_clip_grads(self):
+        # Test `clip_grad` with `clip_grad_norm_`
        optim_wrapper = OptimWrapper(
            self.optimizer, clip_grad=dict(max_norm=35))
        loss = self.model(torch.Tensor(1, 1, 1, 1))
@@ -198,6 +199,15 @@
        optim_wrapper._clip_grad()
        log_scalars = self.message_hub.log_scalars
        self.assertIn('train/grad_norm', log_scalars)
+        self.message_hub._log_scalars.clear()
+
+        # Test `clip_grad` with `clip_grad_value_`
+        optim_wrapper = OptimWrapper(
+            self.optimizer, clip_grad=dict(type='value', clip_value=0.5))
+        loss = self.model(torch.Tensor(1, 1, 1, 1))
+        loss.backward()
+        optim_wrapper._clip_grad()
+        self.assertNotIn('train/grad_norm', log_scalars)
 
    def test_state_dict(self):
        optim_wrapper = OptimWrapper(self.optimizer)
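
Below is a short usage sketch (not part of the patch) showing how the two `clip_grad` modes added above can be exercised end to end. It assumes the `OptimWrapper` interface shown in the diff; the toy model, optimizer, and input shapes are invented purely for illustration.

```python
# A minimal sketch exercising the clip-by-norm and clip-by-value modes of
# OptimWrapper. `ToyModel` and the input tensors are made up for this example.
import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper


class ToyModel(nn.Module):
    """Tiny model whose only purpose is to produce gradients."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        # Return a scalar so the loss can be backpropagated directly.
        return self.linear(x).sum()


model = ToyModel()

# Clip by norm (the default, dispatched to torch.nn.utils.clip_grad_norm_).
# The gradient norm is logged to the message hub as 'train/grad_norm'.
optim_wrapper = OptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.01),
    clip_grad=dict(max_norm=1.0))
loss = model(torch.randn(4, 2))
optim_wrapper.update_params(loss)

# Clip by value (dispatched to torch.nn.utils.clip_grad_value_).
# clip_grad_value_ returns None, so no gradient scalar is logged here.
optim_wrapper = OptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.01),
    clip_grad=dict(type='value', clip_value=0.2))
loss = model(torch.randn(4, 2))
optim_wrapper.update_params(loss)
```

As the added test asserts, the clip-by-value mode logs nothing, so `train/grad_norm` should not appear in the message hub's log scalars for the second wrapper.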