# mmpretrain/tests/test_runtime/test_hooks.py

import logging
import shutil
import tempfile
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.runner import build_runner
from mmcv.runner.hooks import Hook, IterTimerHook
from torch.utils.data import DataLoader
import mmcls.core # noqa: F401
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
                                    max_epochs=1,
                                    max_iters=None,
                                    multi_optimziers=False):
    """Build a runner around a tiny toy model with no hooks registered.

    Args:
        runner_type (str): Runner type passed to ``build_runner``.
            Defaults to 'EpochBasedRunner'.
        max_epochs (int | None): Forwarded to the runner. Defaults to 1.
        max_iters (int | None): Forwarded to the runner. Defaults to None.
        multi_optimziers (bool): If True, the linear and conv layers get
            separate SGD optimizers in a dict keyed ``'model1'`` and
            ``'model2'``; otherwise a single SGD optimizer covers all
            parameters. (The misspelled name is kept for compatibility
            with existing callers.)

    Returns:
        The runner built by ``build_runner``. Its ``work_dir`` is a fresh
        directory from ``tempfile.mkdtemp()`` that the caller must delete.
    """

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)
            # Not used by ``forward``; it only supplies a second parameter
            # set for the 'model2' optimizer.
            self.conv = nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            loss = self(x)
            return {'loss': loss}

        def val_step(self, x, optimizer, **kwargs):
            loss = self(x)
            return {'loss': loss}

    model = Model()

    if not multi_optimziers:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=0.02, momentum=0.95)
    else:
        linear_opt = torch.optim.SGD(
            model.linear.parameters(), lr=0.02, momentum=0.95)
        conv_opt = torch.optim.SGD(
            model.conv.parameters(), lr=0.01, momentum=0.9)
        optimizer = {'model1': linear_opt, 'model2': conv_opt}

    runner_args = dict(
        model=model,
        work_dir=tempfile.mkdtemp(),
        optimizer=optimizer,
        logger=logging.getLogger(),
        max_epochs=max_epochs,
        max_iters=max_iters)
    return build_runner(dict(type=runner_type), default_args=runner_args)
def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None,
                       multi_optimziers=False):
    """Build the demo runner and attach a checkpoint hook and a text logger.

    Arguments are forwarded unchanged to
    ``_build_demo_runner_without_hook``.

    Returns:
        The runner with a checkpoint hook (``interval=1``) and a
        ``TextLoggerHook`` (``interval=1``) registered, in that order.
    """
    runner = _build_demo_runner_without_hook(
        runner_type=runner_type,
        max_epochs=max_epochs,
        max_iters=max_iters,
        multi_optimziers=multi_optimziers)
    runner.register_checkpoint_hook(dict(interval=1))
    runner.register_logger_hooks(
        dict(interval=1, hooks=[dict(type='TextLoggerHook')]))
    return runner
class ValueCheckHook(Hook):
    """Assert that runner attributes reach expected values during training.

    Args:
        check_dict (dict): Maps an iteration index (or an epoch index when
            ``by_epoch`` is True) to a dict of ``{attr_expr: expected}``.
            Each ``attr_expr`` is evaluated as ``runner.<attr_expr>`` and
            compared with ``expected`` using ``np.isclose``.
        by_epoch (bool): If True, values are checked in ``after_epoch``
            keyed by ``runner.epoch``; otherwise in ``after_iter`` keyed by
            ``runner.iter``. Defaults to False.
    """

    def __init__(self, check_dict, by_epoch=False):
        super().__init__()
        self.check_dict = check_dict
        self.by_epoch = by_epoch

    def after_iter(self, runner):
        if not self.by_epoch:
            self._check_values(runner, runner.iter)

    def after_epoch(self, runner):
        if self.by_epoch:
            self._check_values(runner, runner.epoch)

    def _check_values(self, runner, key):
        """Assert every expected value registered under ``key``, if any."""
        if key not in self.check_dict:
            return
        # ``.items()`` is required: iterating the inner dict directly would
        # unpack each key string into ``attr, target``.
        for attr, target in self.check_dict[key].items():
            # NOTE: ``eval`` allows call/index expressions such as
            # ``current_lr()["model1"][0]``; the expressions come from test
            # code, never from external input.
            value = eval(f'runner.{attr}')
            assert np.isclose(value, target), \
                (f'The value of `runner.{attr}` is {value}, '
                 f'not equals to {target}')
@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_cosine_cooldown_hook(multi_optimziers):
    """Check learning rates produced by CosineAnnealingCooldownLrUpdaterHook.

    Runs the demo runner for one training epoch over a 10-sample loader
    (batch size 1, so 10 iterations) and asserts the learning rate at
    iterations 0, 5 and 9, both for a single optimizer and for a dict of
    two optimizers.
    """
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)
    try:
        # Cosine-annealing LR schedule with a cool-down phase, updated per
        # iteration.
        hook_cfg = dict(
            type='CosineAnnealingCooldownLrUpdaterHook',
            by_epoch=False,
            cool_down_time=2,
            cool_down_ratio=0.1,
            min_lr_ratio=0.1,
            warmup_iters=2,
            warmup_ratio=0.9)
        runner.register_hook_from_cfg(hook_cfg)
        # A single IterTimerHook is enough; registering it twice only
        # duplicates the timing work.
        runner.register_hook(IterTimerHook())

        if multi_optimziers:
            check_hook = ValueCheckHook({
                0: {
                    'current_lr()["model1"][0]': 0.02,
                    'current_lr()["model2"][0]': 0.01,
                },
                5: {
                    'current_lr()["model1"][0]': 0.0075558491,
                    'current_lr()["model2"][0]': 0.0037779246,
                },
                9: {
                    'current_lr()["model1"][0]': 0.0002,
                    'current_lr()["model2"][0]': 0.0001,
                }
            })
        else:
            check_hook = ValueCheckHook({
                0: {
                    'current_lr()[0]': 0.02,
                },
                5: {
                    'current_lr()[0]': 0.0075558491,
                },
                9: {
                    'current_lr()[0]': 0.0002,
                }
            })
        # LOWEST priority so the check runs after the LR updater has acted.
        runner.register_hook(check_hook, priority='LOWEST')
        runner.run([loader], [('train', 1)])
    finally:
        # Remove the temporary work_dir even when an assertion fails.
        shutil.rmtree(runner.work_dir)