mmengine/tests/test_hook/test_optimizer_hook.py
Mashiro e0d00c5bdd
[Fix] resolve conflict between adapt and main. (#198)
* [Docs] Refine registry documentation (#186)

* [Docs] Refine registry documentation

* resolve comments

* minor refinement

* Refine Visualizer docs (#177)

* Refine Visualizer docs

* update

* update

* update featmap

* update docs

* update visualizer docs

* [Refactor] Refine LoggerHook (#155)

* rename global accessible and integrate get_instance and create_instance

* move ManagerMixin to utils

* fix as docstring and separate get_instance into get_instance and get_current_instance

* fix lint

* fix docstring, rename and move test_global_meta

* rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume

* refine MMLogger timestamp, update unit test

* MMLogger add logger_name arguments

* Fix docstring

* Add LogProcessor and some unit test

* update unit test

* complete LogProcessor unit test

* refine LoggerHook

* solve circle import

* change default logger_name to mmengine

* refactor eta

* Fix docstring comment and unit test

* Fix with runner

* fix docstring

* fix docstring

* Add by_epoch attribute to LoggerHook and fix docstring

* Please mypy and fix comment

* remove \ in MMLogger

* Fix lint

* roll back pre-commit-hook

* Fix hook unit test

* Fix comments

* remove \t in log and add docstring

* Fix as comment

* should not accept other arguments if corresponding instance has been created

* fix logging ddp file saving

* fix logging ddp file saving

* move log processor to logging

* move log processor to logging

* remove current dataloader

* fix docstring

* fix unit test

* add learning rate in MessageHub

* Support outputting training/validation/testing messages after iterations/epochs

* fix docstring

* Fix IterBasedRunner log string

* Fix IterBasedRunner log string

* Support parsing validation loss in log processor

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188)

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR

* min_lr -> eta_min, refined docstring

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
2022-04-26 00:37:16 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, Mock

import torch
from torch import nn

from mmengine.hooks import OptimizerHook

class TestOptimizerHook:

    def test_after_train_iter(self):

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(
                    in_channels=1,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)
                self.conv2 = nn.Conv2d(
                    in_channels=2,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)
                self.conv3 = nn.Conv2d(
                    in_channels=1,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    dilation=1)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                return x1, x2
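
        # Note: conv3 is defined but never used in ``forward``, so its
        # parameters can never appear in the autograd graph of the loss;
        # the anomalous-parameter detection below is expected to report them.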
        model = Model()
        x = torch.rand(1, 1, 3, 3)

        dummy_runner = MagicMock()
        dummy_runner.optimizer.zero_grad = Mock(return_value=None)
        dummy_runner.optimizer.step = Mock(return_value=None)
        dummy_runner.model = model
        dummy_runner.outputs = dict()
        dummy_runner.outputs['num_samples'] = 0
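
        # DummyLogger accumulates every message passed to ``log`` so the test
        # can check which parameter names were reported as anomalous.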
        class DummyLogger():

            def __init__(self):
                self.msg = ''

            def log(self, msg=None, **kwargs):
                self.msg += msg

        dummy_runner.logger = DummyLogger()

        optimizer_hook = OptimizerHook(
            dict(max_norm=2), detect_anomalous_params=True)
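
        # First pass: the loss is x1.sum() with x1 = conv1(x), so conv2 and
        # conv3 take no part in the backward graph and should be flagged.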
        dummy_runner.outputs['loss'] = model(x)[0].sum()
        dummy_runner.outputs['loss'].backward = Mock(
            wraps=dummy_runner.outputs['loss'].backward)
        optimizer_hook.detect_anomalous_parameters = Mock(
            wraps=optimizer_hook.detect_anomalous_parameters)
        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)

        optimizer_hook.after_train_iter(dummy_runner, 0)

        # assert the parameters of conv2 and conv3 are not in the
        # computational graph which is with x1.sum() as root.
        assert 'conv2.weight' in dummy_runner.logger.msg
        assert 'conv2.bias' in dummy_runner.logger.msg
        assert 'conv3.weight' in dummy_runner.logger.msg
        assert 'conv3.bias' in dummy_runner.logger.msg
        assert 'conv1.weight' not in dummy_runner.logger.msg
        assert 'conv1.bias' not in dummy_runner.logger.msg
        dummy_runner.optimizer.step.assert_called()
        dummy_runner.outputs['loss'].backward.assert_called()
        optimizer_hook.clip_grads.assert_called()
        optimizer_hook.detect_anomalous_parameters.assert_called()
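
        # Second pass: the loss is x2.sum() with x2 = conv2(conv1(x)), so only
        # conv3 remains outside the computational graph.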
        dummy_runner.outputs['loss'] = model(x)[1].sum()
        dummy_runner.logger.msg = ''
        optimizer_hook.after_train_iter(dummy_runner, 0)

        # assert the parameters of conv3 are not in the computational graph
        assert 'conv3.weight' in dummy_runner.logger.msg
        assert 'conv3.bias' in dummy_runner.logger.msg
        assert 'conv2.weight' not in dummy_runner.logger.msg
        assert 'conv2.bias' not in dummy_runner.logger.msg
        assert 'conv1.weight' not in dummy_runner.logger.msg
        assert 'conv1.bias' not in dummy_runner.logger.msg

        # grad_clip is None and detect_anomalous_parameters is False
        optimizer_hook = OptimizerHook(detect_anomalous_params=False)
        optimizer_hook.detect_anomalous_parameters = Mock(
            wraps=optimizer_hook.detect_anomalous_parameters)
        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
        dummy_runner.outputs['loss'] = model(x)[0].sum()
        dummy_runner.outputs['loss'].backward = Mock(
            wraps=dummy_runner.outputs['loss'].backward)

        optimizer_hook.after_train_iter(dummy_runner, 0)

        dummy_runner.optimizer.step.assert_called()
        dummy_runner.outputs['loss'].backward.assert_called()
        optimizer_hook.clip_grads.assert_not_called()
        optimizer_hook.detect_anomalous_parameters.assert_not_called()