mmengine/mmengine/hooks/optimizer_hook.py
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Sequence

import torch
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad

from mmengine.registry import HOOKS
from .hook import Hook

DATA_BATCH = Optional[Sequence[dict]]


@HOOKS.register_module()
class OptimizerHook(Hook):
"""A hook contains custom operations for the optimizer.
Args:
grad_clip (dict, optional): A config dict to control the clip_grad.
Defaults to None.
detect_anomalous_params (bool): This option is only used for
debugging which will slow down the training speed.
Detect anomalous parameters that are not included in
the computational graph with ``loss`` as the root.
There are two cases
- Parameters were not used during
forward pass.
- Parameters were not used to produce
loss.
Defaults to False.
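
    Examples:
        >>> # A minimal usage sketch; ``max_norm`` and ``norm_type`` are
        >>> # illustrative values that are forwarded unchanged to
        >>> # ``torch.nn.utils.clip_grad_norm_``, not defaults of this hook.
        >>> hook = OptimizerHook(
        ...     grad_clip=dict(max_norm=35, norm_type=2),
        ...     detect_anomalous_params=False)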
"""
priority = 'HIGH'
def __init__(self,
grad_clip: Optional[dict] = None,
detect_anomalous_params: bool = False) -> None:
self.grad_clip = grad_clip
self.detect_anomalous_params = detect_anomalous_params

    def clip_grads(self, params: List[Parameter]) -> Optional[torch.Tensor]:
        """Clip the gradients of parameters.

        Args:
            params (list[Parameter]): Model's parameters.

        Returns:
            Optional[torch.Tensor]: Total norm of the parameters if there is
            at least one param requiring gradient, else None.
        """
        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
        return None

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """All operations need to be finished after each training iteration.

        This function will finish the following operations:

        - Detect any anomalous parameters which are not included in the
          training graph. (optional)
        - Compute the gradient of model parameters.
        - Clip the gradients of each parameter. (optional)
        - Update model parameters with gradients.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                In order to keep this interface consistent with other hooks,
                we keep ``data_batch`` here. Defaults to None.
            outputs (dict, optional): Outputs from model.
                In order to keep this interface consistent with other hooks,
                we keep ``outputs`` here. Defaults to None.
        """
        runner.optimizer.zero_grad()
        runner.message_hub.update_scalar(
            'train/lr', runner.optimizer.param_groups[0]['lr'])

        if self.detect_anomalous_params:
            self.detect_anomalous_parameters(runner.outputs['loss'], runner)

        runner.outputs['loss'].backward()
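
        # Optionally clip gradients; ``clip_grads`` returns the total norm,
        # which is logged together with the number of samples in this batch.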
        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:
                # Add grad norm to the logger
                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()

    def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
        """Detect anomalous parameters that are not included in the graph.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            runner (Runner): The runner of the training process.
        """
        logger = runner.logger
        parameters_in_graph = set()
        visited = set()
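
        # Recursively walk the autograd graph rooted at ``loss.grad_fn`` and
        # collect every leaf parameter (nodes exposing a ``variable``
        # attribute) that contributed to the loss.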
        def traverse(grad_fn):
            if grad_fn is None:
                return
            if grad_fn not in visited:
                visited.add(grad_fn)
                if hasattr(grad_fn, 'variable'):
                    parameters_in_graph.add(grad_fn.variable)
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        grad_fn = parent[0]
                        traverse(grad_fn)

        traverse(loss.grad_fn)
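        # Any parameter that requires grad but never appeared in the
        # traversal did not contribute to ``loss``; report it as anomalous.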
        for n, p in runner.model.named_parameters():
            if p not in parameters_in_graph and p.requires_grad:
                logger.log(
                    level=logging.ERROR,
                    msg=f'{n} with shape {p.size()} is not '
                    f'in the computational graph \n')