mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* [Docs] Refine registry documentation (#186) * [Docs] Refine registry documentation * resolve comments * minor refinement * Refine Visualizer docs (#177) * Refine Visualizer docs * update * update * update featmap * update docs * update visualizer docs * [Refactor] Refine LoggerHook (#155) * rename global accessible and integrate get_instance and create_instance * move ManagerMixin to utils * fix as docstring and separate get_instance to get_instance and get_current_instance * fix lint * fix docstring, rename and move test_global_meta * rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume * refine MMLogger timestamp, update unit test * MMLogger add logger_name arguments * Fix docstring * Add LogProcessor and some unit test * update unit test * complete LogProcessor unit test * refine LoggerHook * solve circular import * change default logger_name to mmengine * refactor eta * Fix docstring comment and unit test * Fix with runner * fix docstring fix docstring * fix docstring * Add by_epoch attribute to LoggerHook and fix docstring * Please mypy and fix comment * remove \ in MMLogger * Fix lint * roll back pre-commit-hook * Fix hook unit test * Fix comments * remove \t in log and add docstring * Fix as comment * should not accept other arguments if corresponding instance has been created * fix logging ddp file saving * fix logging ddp file saving * move log processor to logging * move log processor to logging * remove current dataloader * fix docstring * fix unit test * add learning rate in messagehub * Support output training/validation/testing message after iterations/epochs * fix docstring * Fix IterBasedRunner log string * Fix IterBasedRunner log string * Support parse validation loss in log processor * [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188) * [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR * min_lr -> eta_min, refined docstring
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com> Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
116 lines
4.3 KiB
Python
116 lines
4.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest.mock import MagicMock, Mock
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from mmengine.hooks import OptimizerHook
|
|
|
|
|
|
class TestOptimizerHook:
    """Unit tests for ``OptimizerHook.after_train_iter``."""

    def test_after_train_iter(self):
        """Exercise ``after_train_iter`` in three scenarios.

        1. ``grad_clip=dict(max_norm=2)`` with ``detect_anomalous_params=True``
           and a loss that only depends on ``conv1``.
        2. The same hook with a loss that depends on ``conv1`` and ``conv2``.
        3. ``grad_clip`` left unset (None) with
           ``detect_anomalous_params=False``: neither ``clip_grads`` nor
           ``detect_anomalous_parameters`` may be called.

        The runner is a ``MagicMock``; its optimizer's ``zero_grad`` and
        ``step`` are plain mocks, and its logger collects every logged
        message into one string so the test can search it for parameter
        names.
        """

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(
                    in_channels=1,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)
                self.conv2 = nn.Conv2d(
                    in_channels=2,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)
                self.conv3 = nn.Conv2d(
                    in_channels=1,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                # conv3 is deliberately never used, so its parameters are
                # never part of a graph rooted at either returned output.
                return x1, x2

        model = Model()
        x = torch.rand(1, 1, 3, 3)

        dummy_runner = MagicMock()
        dummy_runner.optimizer.zero_grad = Mock(return_value=None)
        dummy_runner.optimizer.step = Mock(return_value=None)
        dummy_runner.model = model
        dummy_runner.outputs = dict()
        dummy_runner.outputs['num_samples'] = 0

        class DummyLogger:
            """Accumulate every logged message into ``self.msg``."""

            def __init__(self):
                self.msg = ''

            def log(self, msg=None, **kwargs):
                self.msg += msg

        dummy_runner.logger = DummyLogger()

        # Scenario 1: gradient clipping and anomalous-parameter detection
        # both enabled; the loss is rooted at x1.sum().
        optimizer_hook = OptimizerHook(
            dict(max_norm=2), detect_anomalous_params=True)
        dummy_runner.outputs['loss'] = model(x)[0].sum()
        # Wrap the real callables so their calls can be asserted while
        # their original behavior still runs.
        dummy_runner.outputs['loss'].backward = Mock(
            wraps=dummy_runner.outputs['loss'].backward)
        optimizer_hook.detect_anomalous_parameters = Mock(
            wraps=optimizer_hook.detect_anomalous_parameters)
        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)

        optimizer_hook.after_train_iter(dummy_runner, 0)
        # x1.sum() only depends on conv1, so the parameters of conv2 and
        # conv3 must be reported as outside its computational graph.
        assert 'conv2.weight' in dummy_runner.logger.msg
        assert 'conv2.bias' in dummy_runner.logger.msg
        assert 'conv3.weight' in dummy_runner.logger.msg
        assert 'conv3.bias' in dummy_runner.logger.msg
        assert 'conv1.weight' not in dummy_runner.logger.msg
        assert 'conv1.bias' not in dummy_runner.logger.msg
        dummy_runner.optimizer.step.assert_called()
        dummy_runner.outputs['loss'].backward.assert_called()
        optimizer_hook.clip_grads.assert_called()
        optimizer_hook.detect_anomalous_parameters.assert_called()

        # Scenario 2: same hook, loss rooted at x2.sum(), which depends on
        # conv1 and conv2; only conv3 is outside the graph.
        dummy_runner.outputs['loss'] = model(x)[1].sum()
        dummy_runner.logger.msg = ''
        # Forget calls from the previous scenario so the step assertion
        # below is attributable to this iteration alone.
        dummy_runner.optimizer.step.reset_mock()

        optimizer_hook.after_train_iter(dummy_runner, 0)
        assert 'conv3.weight' in dummy_runner.logger.msg
        assert 'conv3.bias' in dummy_runner.logger.msg
        assert 'conv2.weight' not in dummy_runner.logger.msg
        assert 'conv2.bias' not in dummy_runner.logger.msg
        assert 'conv1.weight' not in dummy_runner.logger.msg
        assert 'conv1.bias' not in dummy_runner.logger.msg
        dummy_runner.optimizer.step.assert_called_once()

        # Scenario 3: grad_clip is None and detect_anomalous_params is
        # False, so clipping and anomaly detection must be skipped.
        optimizer_hook = OptimizerHook(detect_anomalous_params=False)
        optimizer_hook.detect_anomalous_parameters = Mock(
            wraps=optimizer_hook.detect_anomalous_parameters)
        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
        dummy_runner.outputs['loss'] = model(x)[0].sum()
        dummy_runner.outputs['loss'].backward = Mock(
            wraps=dummy_runner.outputs['loss'].backward)
        # The optimizer mock is shared with the earlier scenarios; without
        # this reset, ``step`` would already count as called and the check
        # below would pass even if this iteration never stepped.
        dummy_runner.optimizer.step.reset_mock()

        optimizer_hook.after_train_iter(dummy_runner, 0)

        dummy_runner.optimizer.step.assert_called_once()
        dummy_runner.outputs['loss'].backward.assert_called()
        optimizer_hook.clip_grads.assert_not_called()
        optimizer_hook.detect_anomalous_parameters.assert_not_called()
|