Mirror of https://github.com/open-mmlab/mmengine.git
[Feature]: Add optimizer hook (#70)
* [Feature]: Add optimizer hook
* [Fix]: Update docstring
* [Fix]: Add call with in UT
parent ee95ce2488
commit 63a3af4f8c
mmengine/hooks/__init__.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
-from .sampler_seed_hook import DistSamplerSeedHook
+from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
+from .sampler_seed_hook import DistSamplerSeedHook

 __all__ = [
-    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook'
+    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
+    'OptimizerHook'
 ]
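With this change the new hook is exported from the package root. A minimal sanity-check sketch (not part of the diff), assuming mmengine with this commit is installed:

# Hypothetical check, not in the commit: OptimizerHook is re-exported
# alongside the other hooks and subclasses the base Hook.
from mmengine.hooks import Hook, OptimizerHook

assert issubclass(OptimizerHook, Hook)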
mmengine/hooks/optimizer_hook.py (new file, 130 lines)
@@ -0,0 +1,130 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Sequence

import torch
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad

from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook


@HOOKS.register_module()
class OptimizerHook(Hook):
    """A hook that contains custom operations for the optimizer.

    Args:
        grad_clip (dict, optional): A config dict to control the clip_grad.
            Defaults to None.
        detect_anomalous_params (bool): This option is only used for
            debugging and will slow down the training speed. It detects
            anomalous parameters that are not included in the
            computational graph with ``loss`` as the root.
            There are two cases:

            - Parameters were not used during the forward pass.
            - Parameters were not used to produce the loss.

            Defaults to False.
    """

    def __init__(self,
                 grad_clip: Optional[dict] = None,
                 detect_anomalous_params: bool = False) -> None:
        self.grad_clip = grad_clip
        self.detect_anomalous_params = detect_anomalous_params

    def clip_grads(self, params: List[Parameter]) -> Optional[torch.Tensor]:
        """Clip the gradients of parameters.

        Args:
            params (list[Parameter]): Model's parameters.

        Returns:
            Optional[torch.Tensor]: Total norm of the parameters if there is
            at least one param requiring gradient, else None.
        """
        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
        return None

    def after_train_iter(
            self,
            runner: object,
            data_batch: Optional[Sequence[BaseDataSample]] = None,
            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
        """All operations that need to be finished after each training
        iteration.

        This function performs the following four operations:

        - Detect any anomalous parameters which are not included in the
          training graph. (optional)
        - Compute the gradients of model parameters.
        - Clip the gradients of each parameter. (optional)
        - Update model parameters with gradients.

        Args:
            runner (object): The runner of the training process.
            data_batch (Sequence[BaseDataSample], optional): Data from
                dataloader. In order to keep this interface consistent with
                other hooks, we keep ``data_batch`` here. Defaults to None.
            outputs (Sequence[BaseDataSample], optional): Outputs from model.
                In order to keep this interface consistent with other hooks,
                we keep ``outputs`` here. Defaults to None.
        """
        runner.optimizer.zero_grad()  # type: ignore
        if self.detect_anomalous_params:
            self.detect_anomalous_parameters(
                runner.outputs['loss'],  # type: ignore
                runner)
        runner.outputs['loss'].backward()  # type: ignore

        if self.grad_clip is not None:
            grad_norm = self.clip_grads(
                runner.model.parameters())  # type: ignore
            if grad_norm is not None:
                # Add grad norm to the logger
                runner.log_buffer.update(  # type: ignore
                    {'grad_norm': float(grad_norm)},
                    runner.outputs['num_samples'])  # type: ignore
        runner.optimizer.step()  # type: ignore

    def detect_anomalous_parameters(self, loss: torch.Tensor,
                                    runner: object) -> None:
        """Detect anomalous parameters that are not included in the graph.

        Args:
            loss (torch.Tensor): The loss of the current iteration.
            runner (object): The runner of the training process.
        """
        logger = runner.logger  # type: ignore
        parameters_in_graph = set()
        visited = set()

        def traverse(grad_fn):
            if grad_fn is None:
                return
            if grad_fn not in visited:
                visited.add(grad_fn)
                if hasattr(grad_fn, 'variable'):
                    parameters_in_graph.add(grad_fn.variable)
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        grad_fn = parent[0]
                        traverse(grad_fn)

        traverse(loss.grad_fn)
        for n, p in runner.model.named_parameters():  # type: ignore
            if p not in parameters_in_graph and p.requires_grad:
                logger.log(
                    level=logging.ERROR,
                    msg=f'{n} with shape {p.size()} is not '
                    f'in the computational graph\n')
tests/test_hook/test_optimizer_hook.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock

import torch
from torch import nn

from mmengine.hooks import OptimizerHook


class TestOptimizerHook:

    def test_after_train_iter(self):

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(
                    in_channels=1,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)
                self.conv2 = nn.Conv2d(
                    in_channels=2,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)
                self.conv3 = nn.Conv2d(
                    in_channels=1,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                return x1, x2

        model = Model()
        x = torch.rand(1, 1, 3, 3)

        dummy_runner = Mock()
        dummy_runner.optimizer.zero_grad = Mock(return_value=None)
        dummy_runner.optimizer.step = Mock(return_value=None)
        dummy_runner.model = model
        dummy_runner.outputs = dict()

        dummy_runner.outputs['num_samples'] = 0

        class DummyLogger():

            def __init__(self):
                self.msg = ''

            def log(self, msg=None, **kwargs):
                self.msg += msg

        dummy_runner.logger = DummyLogger()
        optimizer_hook = OptimizerHook(
            dict(max_norm=2), detect_anomalous_params=True)

        dummy_runner.outputs['loss'] = model(x)[0].sum()

        dummy_runner.outputs['loss'].backward = Mock(
            wraps=dummy_runner.outputs['loss'].backward)
        optimizer_hook.detect_anomalous_parameters = Mock(
            wraps=optimizer_hook.detect_anomalous_parameters)
        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)

        optimizer_hook.after_train_iter(dummy_runner)
        # Assert that the parameters of conv2 and conv3 are not in the
        # computational graph whose root is x1.sum().
        assert 'conv2.weight' in dummy_runner.logger.msg
        assert 'conv2.bias' in dummy_runner.logger.msg
        assert 'conv3.weight' in dummy_runner.logger.msg
        assert 'conv3.bias' in dummy_runner.logger.msg
        assert 'conv1.weight' not in dummy_runner.logger.msg
        assert 'conv1.bias' not in dummy_runner.logger.msg
        dummy_runner.optimizer.step.assert_called()
        dummy_runner.outputs['loss'].backward.assert_called()
        optimizer_hook.clip_grads.assert_called()
        optimizer_hook.detect_anomalous_parameters.assert_called()

        dummy_runner.outputs['loss'] = model(x)[1].sum()
        dummy_runner.logger.msg = ''
        optimizer_hook.after_train_iter(dummy_runner)
        # Assert that only the parameters of conv3 are not in the
        # computational graph.
        assert 'conv3.weight' in dummy_runner.logger.msg
        assert 'conv3.bias' in dummy_runner.logger.msg
        assert 'conv2.weight' not in dummy_runner.logger.msg
        assert 'conv2.bias' not in dummy_runner.logger.msg
        assert 'conv1.weight' not in dummy_runner.logger.msg
        assert 'conv1.bias' not in dummy_runner.logger.msg

        # grad_clip is None and detect_anomalous_parameters is False
        optimizer_hook = OptimizerHook(detect_anomalous_params=False)
        optimizer_hook.detect_anomalous_parameters = Mock(
            wraps=optimizer_hook.detect_anomalous_parameters)
        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
        dummy_runner.outputs['loss'] = model(x)[0].sum()
        dummy_runner.outputs['loss'].backward = Mock(
            wraps=dummy_runner.outputs['loss'].backward)

        optimizer_hook.after_train_iter(dummy_runner)

        dummy_runner.optimizer.step.assert_called()
        dummy_runner.outputs['loss'].backward.assert_called()
        optimizer_hook.clip_grads.assert_not_called()
        optimizer_hook.detect_anomalous_parameters.assert_not_called()
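For readers unfamiliar with why conv2/conv3 get flagged above, here is a standalone sketch (not part of the commit) of the autograd-graph walk that ``detect_anomalous_parameters`` performs: parameters reachable from ``loss.grad_fn`` via ``next_functions`` show up as nodes carrying a ``variable`` attribute, so anything missing from that set never contributed to the loss. The module names ``used``/``unused`` are illustrative only.

import torch
from torch import nn

net = nn.ModuleDict({'used': nn.Linear(3, 3), 'unused': nn.Linear(3, 3)})
loss = net['used'](torch.rand(1, 3)).sum()

visited = set()
reachable = set()


def walk(grad_fn):
    # Depth-first walk over the autograd graph rooted at ``loss.grad_fn``.
    if grad_fn is None or grad_fn in visited:
        return
    visited.add(grad_fn)
    if hasattr(grad_fn, 'variable'):  # leaf node holding a parameter
        reachable.add(grad_fn.variable)
    for parent, _ in grad_fn.next_functions:
        walk(parent)


walk(loss.grad_fn)
# Parameters of the unused branch never appear in the graph, which is
# exactly the condition the hook reports as anomalous.
assert all(p in reachable for p in net['used'].parameters())
assert all(p not in reachable for p in net['unused'].parameters())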