mmcv/tests/test_runner/test_hooks.py

418 lines
13 KiB
Python

"""Tests the hooks with runners.

CommandLine:
    pytest tests/test_runner/test_hooks.py
    xdoctest tests/test_runner/test_hooks.py zero
"""
import logging
import os.path as osp
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, call
import pytest
import torch
import torch.nn as nn
from torch.nn.init import constant_
from torch.utils.data import DataLoader
from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
build_runner)
from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
def test_checkpoint_hook():
    """Check that CheckpointHook records the last checkpoint path in meta.

    xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook

    An epoch based and an iteration based runner are exercised in turn. Each
    runner's temporary work dir is removed in a ``finally`` clause so that a
    failing run or assertion does not leave the directory behind.
    """
    loader = DataLoader(torch.ones((5, 2)))

    # test epoch based runner
    runner = _build_demo_runner('EpochBasedRunner', max_epochs=1)
    try:
        runner.meta = dict()
        checkpointhook = CheckpointHook(interval=1, by_epoch=True)
        runner.register_hook(checkpointhook)
        runner.run([loader], [('train', 1)])
        assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
            runner.work_dir, 'epoch_1.pth')
    finally:
        shutil.rmtree(runner.work_dir)

    # test iter based runner
    runner = _build_demo_runner(
        'IterBasedRunner', max_iters=1, max_epochs=None)
    try:
        runner.meta = dict()
        checkpointhook = CheckpointHook(interval=1, by_epoch=False)
        runner.register_hook(checkpointhook)
        runner.run([loader], [('train', 1)])
        assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
            runner.work_dir, 'iter_1.pth')
    finally:
        shutil.rmtree(runner.work_dir)
def test_ema_hook():
    """Run EMAHook from scratch, then resume it from a patched checkpoint.

    xdoctest -m tests/test_hooks.py test_ema_hook

    The first run saves ``epoch_1.pth``; every tensor in its state dict is
    expected to sum to zero and at least one key must contain ``'ema'``.
    Those ``'ema'`` tensors are filled with ones and the checkpoint is written
    back, then a second run resumes from that file and the tensors saved in
    ``epoch_2.pth`` are checked.
    """

    class DemoModel(nn.Module):
        """Single 1x1 convolution whose weight and bias start at zero."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=1,
                padding=1,
                bias=True)
            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)

        def forward(self, x):
            return self.conv(x).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    data = DataLoader(torch.ones((1, 1, 1, 1)))
    workflow = [('train', 1), ('val', 1)]

    # ---- first run: EMA from scratch ----
    runner = _build_demo_runner()
    model = DemoModel()
    runner.model = model
    runner.register_hook(
        EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None),
        priority='HIGHEST')
    runner.register_hook(CheckpointHook(interval=1, by_epoch=True))
    runner.run([data, data], workflow)

    first_work_dir = runner.work_dir
    first_ckpt_path = f'{first_work_dir}/epoch_1.pth'
    ckpt = torch.load(first_ckpt_path)
    found_ema = False
    for key, tensor in ckpt['state_dict'].items():
        assert tensor.sum() == 0
        if 'ema' in key:
            found_ema = True
            # Patch the EMA tensors in place so the resumed run starts from
            # values that differ from the plain parameters.
            tensor.fill_(1)
    assert found_ema
    torch.save(ckpt, first_ckpt_path)

    # ---- second run: resume EMA state from the patched checkpoint ----
    resumed_hook = EMAHook(
        momentum=0.5, warm_up=0, resume_from=first_ckpt_path)
    runner = _build_demo_runner(max_epochs=2)
    runner.model = model
    runner.register_hook(resumed_hook, priority='HIGHEST')
    runner.register_hook(CheckpointHook(interval=1, by_epoch=True))
    runner.run([data, data], workflow)

    ckpt = torch.load(f'{runner.work_dir}/epoch_2.pth')
    found_ema = False
    for key, tensor in ckpt['state_dict'].items():
        is_ema = 'ema' in key
        if is_ema:
            found_ema = True
        assert tensor.sum() == (2 if is_ema else 1)
    assert found_ema

    shutil.rmtree(runner.work_dir)
    shutil.rmtree(first_work_dir)
def test_pavi_hook():
    """Check the scalars and snapshot file that PaviLoggerHook reports.

    ``pavi`` is replaced by a ``MagicMock`` in ``sys.modules``, so the hook's
    writer is a mock whose recorded calls can be inspected.
    """
    sys.modules['pavi'] = MagicMock()

    data = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    runner.meta = {'config_dict': {'lr': 0.02, 'gpu_ids': range(1)}}
    pavi_hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
    runner.register_hook(pavi_hook)
    runner.run([data, data], [('train', 1), ('val', 1)])

    # Only the path string is needed below, so the directory can go now.
    work_dir = runner.work_dir
    shutil.rmtree(work_dir)

    assert hasattr(pavi_hook, 'writer')
    expected_scalars = {'learning_rate': 0.02, 'momentum': 0.95}
    pavi_hook.writer.add_scalars.assert_called_with(
        'val', expected_scalars, 1)
    pavi_hook.writer.add_snapshot_file.assert_called_with(
        tag=work_dir.split('/')[-1],
        snapshot_file_path=osp.join(work_dir, 'epoch_1.pth'),
        iteration=1)
def test_sync_buffers_hook():
    """Smoke test: a train/val run with SyncBuffersHook registered."""
    data = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()

    sync_cfg = {'type': 'SyncBuffersHook'}
    runner.register_hook_from_cfg(sync_cfg)

    workflow = [('train', 1), ('val', 1)]
    runner.run([data, data], workflow)
    shutil.rmtree(runner.work_dir)
def test_momentum_runner_hook():
    """Check the values logged with cyclic momentum and LR updater hooks.

    xdoctest -m tests/test_hooks.py test_momentum_runner_hook

    ``pavi`` is mocked so the scalars passed to the PaviLoggerHook writer can
    be compared against expected learning rate / momentum pairs.
    """
    sys.modules['pavi'] = MagicMock()
    data = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cyclic momentum scheduler
    momentum_cfg = dict(
        type='CyclicMomentumUpdaterHook',
        by_epoch=False,
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook_from_cfg(momentum_cfg)

    # add cyclic LR scheduler
    lr_cfg = dict(
        type='CyclicLrUpdaterHook',
        by_epoch=False,
        target_ratio=(10, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook_from_cfg(lr_cfg)
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))

    # add pavi hook
    pavi_hook = PaviLoggerHook(
        interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(pavi_hook)
    runner.run([data], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(pavi_hook, 'writer')
    # (learning_rate, momentum, step) triples expected among the logged calls
    expected = (
        (0.01999999999999999, 0.95, 1),
        (0.2, 0.85, 5),
        (0.155, 0.875, 7),
    )
    calls = [
        call('train', {'learning_rate': lr, 'momentum': momentum}, step)
        for lr, momentum, step in expected
    ]
    pavi_hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
def test_cosine_runner_hook():
    """Check the values logged with cosine annealing momentum and LR hooks.

    xdoctest -m tests/test_hooks.py test_cosine_runner_hook

    ``pavi`` is mocked so the scalars passed to the PaviLoggerHook writer can
    be compared against expected learning rate / momentum pairs.
    """
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cosine annealing momentum scheduler
    hook_cfg = dict(
        type='CosineAnnealingMomentumUpdaterHook',
        min_momentum_ratio=0.99 / 0.95,
        by_epoch=False,
        warmup_iters=2,
        warmup_ratio=0.9 / 0.95)
    runner.register_hook_from_cfg(hook_cfg)

    # add cosine annealing LR scheduler
    hook_cfg = dict(
        type='CosineAnnealingLrUpdaterHook',
        by_epoch=False,
        min_lr_ratio=0,
        warmup_iters=2,
        warmup_ratio=0.9)
    runner.register_hook_from_cfg(hook_cfg)

    # The timer hook is registered once; a second IterTimerHook instance
    # would only duplicate the same hook.
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, 1),
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.97
        }, 6),
        call('train', {
            'learning_rate': 0.0004894348370484647,
            'momentum': 0.9890211303259032
        }, 10)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
def test_cosine_restart_lr_update_hook():
    """Test CosineRestartLrUpdaterHook.

    Covers three cases: invalid constructor arguments (``AssertionError``),
    periods whose cumulative sum is too small for the run (``ValueError``),
    and the learning rates logged for a valid configuration.
    """
    with pytest.raises(AssertionError):
        # either `min_lr` or `min_lr_ratio` should be specified
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5, 0.5],
            min_lr=0.1,
            min_lr_ratio=0)

    with pytest.raises(AssertionError):
        # periods and restart_weights should have the same length
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5],
            min_lr_ratio=0)

    # the last cumulative_periods 7 (out of [5, 7]) should >= 10
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()
    try:
        with pytest.raises(ValueError):
            # add cosine restart LR scheduler
            hook = CosineRestartLrUpdaterHook(
                by_epoch=False,
                periods=[5, 2],  # cumulative_periods [5, 7 (5 + 2)]
                restart_weights=[0.5, 0.5],
                min_lr=0.0001)
            runner.register_hook(hook)
            runner.register_hook(IterTimerHook())

            # add pavi hook
            hook = PaviLoggerHook(
                interval=1, add_graph=False, add_last_ckpt=True)
            runner.register_hook(hook)
            runner.run([loader], [('train', 1)])
    finally:
        # Clean up here rather than inside the ``raises`` block: code placed
        # after the raising call would never execute.
        shutil.rmtree(runner.work_dir)

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cosine restart LR scheduler
    hook = CosineRestartLrUpdaterHook(
        by_epoch=False,
        periods=[5, 5],
        restart_weights=[0.5, 0.5],
        min_lr_ratio=0)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.95
        }, 1),
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.95
        }, 6),
        call('train', {
            'learning_rate': 0.0009549150281252633,
            'momentum': 0.95
        }, 10)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
    """Check the calls MlflowLoggerHook makes on a mocked ``mlflow``.

    Args:
        log_model (bool): Passed through to ``MlflowLoggerHook``. When true
            the model is expected to be logged under ``'models'``; otherwise
            ``log_model`` must not have been called.
    """
    # Both modules are mocked so the hook's calls are recorded, not executed.
    for module_name in ('mlflow', 'mlflow.pytorch'):
        sys.modules[module_name] = MagicMock()

    runner = _build_demo_runner()
    data = DataLoader(torch.ones((5, 2)))

    mlflow_hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
    runner.register_hook(mlflow_hook)
    runner.run([data, data], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    mlflow_hook.mlflow.set_experiment.assert_called_with('test')
    expected_metrics = {'learning_rate': 0.02, 'momentum': 0.95}
    mlflow_hook.mlflow.log_metrics.assert_called_with(
        expected_metrics, step=6)

    model_logger = mlflow_hook.mlflow_pytorch.log_model
    if not log_model:
        assert not model_logger.called
    else:
        model_logger.assert_called_with(runner.model, 'models')
def test_wandb_hook():
    """Check the ``init``/``log``/``join`` calls made on a mocked ``wandb``."""
    sys.modules['wandb'] = MagicMock()

    runner = _build_demo_runner()
    wandb_hook = WandbLoggerHook()
    data = DataLoader(torch.ones((5, 2)))
    runner.register_hook(wandb_hook)

    workflow = [('train', 1), ('val', 1)]
    runner.run([data, data], workflow)
    shutil.rmtree(runner.work_dir)

    wandb_mock = wandb_hook.wandb
    wandb_mock.init.assert_called_with()
    expected_metrics = {'learning_rate': 0.02, 'momentum': 0.95}
    wandb_mock.log.assert_called_with(expected_metrics, step=6, commit=True)
    wandb_mock.join.assert_called_with()
def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None):
    """Build a small runner around a one-layer linear model for the tests.

    The runner gets an SGD optimizer (lr=0.02, momentum=0.95), the root
    logger, and a fresh temporary directory from ``tempfile.mkdtemp`` as its
    work dir; nothing here removes that directory. A checkpoint hook
    (interval 1) and a ``TextLoggerHook`` (interval 1) are registered before
    the runner is returned.

    Args:
        runner_type (str): Value of the ``type`` key handed to
            ``build_runner``.
        max_epochs (int | None): Forwarded to the runner as ``max_epochs``.
        max_iters (int | None): Forwarded to the runner as ``max_iters``.

    Returns:
        The runner created by ``build_runner``.
    """

    class Model(nn.Module):
        """Linear(2, 1) model whose output doubles as the loss."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    net = Model()
    sgd = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.95)
    logger_cfg = {'interval': 1, 'hooks': [{'type': 'TextLoggerHook'}]}
    work_dir = tempfile.mkdtemp()

    runner_args = {
        'model': net,
        'work_dir': work_dir,
        'optimizer': sgd,
        'logger': logging.getLogger(),
        'max_epochs': max_epochs,
        'max_iters': max_iters,
    }
    runner = build_runner({'type': runner_type}, default_args=runner_args)
    runner.register_checkpoint_hook({'interval': 1})
    runner.register_logger_hooks(logger_cfg)
    return runner